blob: f9ea302e8fdfddffb580bd783f1b99379b7ae64c [file] [log] [blame]
// Copyright 2020 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/ast_parser/function.h"
#include <algorithm>
#include <array>
#include "src/tint/lang/core/builtin_fn.h"
#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/core/type/texture_dimension.h"
#include "src/tint/lang/spirv/reader/ast_lower/atomics.h"
#include "src/tint/lang/wgsl/ast/assignment_statement.h"
#include "src/tint/lang/wgsl/ast/break_statement.h"
#include "src/tint/lang/wgsl/ast/builtin_attribute.h"
#include "src/tint/lang/wgsl/ast/call_statement.h"
#include "src/tint/lang/wgsl/ast/continue_statement.h"
#include "src/tint/lang/wgsl/ast/discard_statement.h"
#include "src/tint/lang/wgsl/ast/if_statement.h"
#include "src/tint/lang/wgsl/ast/loop_statement.h"
#include "src/tint/lang/wgsl/ast/return_statement.h"
#include "src/tint/lang/wgsl/ast/stage_attribute.h"
#include "src/tint/lang/wgsl/ast/switch_statement.h"
#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
#include "src/tint/utils/containers/hashmap.h"
#include "src/tint/utils/containers/hashset.h"
#include "src/tint/utils/rtti/switch.h"
// Terms:
// CFG: the control flow graph of the function, where basic blocks are the
// nodes, and branches form the directed arcs. The function entry block is
// the root of the CFG.
//
// Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
// Then:
// - Let M(H) be the merge block named by the merge instruction in H.
// - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
// CT(H) be the continue target block named by the OpLoopMerge
// instruction.
// - If H is a selection construct whose header ends in
// OpBranchConditional with true target %then and false target %else,
// then TT(H) = %then and FT(H) = %else
//
// Determining output block order:
// The "structured post-order traversal" of the CFG is a post-order traversal
// of the basic blocks in the CFG, where:
// We visit the entry node of the function first.
// When visiting a header block:
// We next visit its merge block
// Then if it's a loop header, we next visit the continue target,
// Then we visit the block's successors (whether it's a header or not)
// If the block ends in an OpBranchConditional, we visit the false target
// before the true target.
//
// The "reverse structured post-order traversal" of the CFG is the reverse
// of the structured post-order traversal.
// This is the order of basic blocks as they should be emitted to the WGSL
// function. It is the order computed by ComputeBlockOrder, and stored in
// the |FunctionEmitter::block_order_|.
// Blocks not in this ordering are ignored by the rest of the algorithm.
//
// Note:
// - A block D in the function might not appear in this order because
// no block in the order branches to D.
// - An unreachable block D might still be in the order because some header
// block in the order names D as its continue target, or merge block,
// or D is reachable from one of those otherwise-unreachable continue
// targets or merge blocks.
//
// Terms:
// Let Pos(B) be the index position of a block B in the computed block order.
//
// CFG intervals and valid nesting:
//
// A correctly structured CFG satisfies nesting rules that we can check by
// comparing positions of related blocks.
//
// If header block H is in the block order, then the following holds:
//
// Pos(H) < Pos(M(H))
//
// If CT(H) exists, then:
//
// Pos(H) <= Pos(CT(H))
// Pos(CT(H)) < Pos(M(H))
//
// This gives us the fundamental ordering of blocks in relation to a
// structured construct:
// The blocks before H in the block order, are not in the construct
// The blocks at M(H) or later in the block order, are not in the construct
// The blocks in a selection headed at H are in positions
// [ Pos(H), Pos(M(H)) )
// The blocks in a loop construct headed at H are in positions
// [ Pos(H), Pos(CT(H)) )
// The blocks in the continue construct for loop headed at H are in
// positions [ Pos(CT(H)), Pos(M(H)) )
//
// Schematically, for a selection construct headed by H, the blocks are in
// order from left to right:
//
// ...a-b-c H d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the selection construct
// where H and d-e-f: blocks in the selection construct
// where M(H) and n-o-p...: blocks after the selection construct
//
// Schematically, for a loop construct headed by H that is its own
// continue construct, the blocks in order from left to right:
//
// ...a-b-c H=CT(H) d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H is the continue construct; CT(H)=H, and the loop construct
// is *empty*
// where d-e-f... are other blocks in the continue construct
// where M(H) and n-o-p...: blocks after the continue construct
//
// Schematically, for a multi-block loop construct headed by H, there are
// blocks in order from left to right:
//
// ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H and d-e-f: blocks in the loop construct
// where CT(H) and j-k-l: blocks in the continue construct
// where M(H) and n-o-p...: blocks after the loop and continue
// constructs
//
using namespace tint::core::number_suffixes; // NOLINT
using namespace tint::core::fluent_types; // NOLINT
namespace tint::spirv::reader::ast_parser {
namespace {
/// Vector length limit used by this file.
/// NOTE(review): the uses of this constant are outside the portion of the file
/// reviewed here; presumably it is the largest vector component count - confirm.
constexpr uint32_t kMaxVectorLen = 4;
/// Free-function accessor for the opcode of an instruction passed by reference.
/// @param inst a SPIR-V instruction
/// @returns the opcode of the instruction
inline spv::Op opcode(const spvtools::opt::Instruction& inst) {
return inst.opcode();
}
/// Free-function accessor for the opcode of an instruction passed by pointer.
/// The pointer is dereferenced without a null check.
/// @param inst a SPIR-V instruction pointer
/// @returns the opcode of the instruction
inline spv::Op opcode(const spvtools::opt::Instruction* inst) {
return inst->opcode();
}
// Looks up the AST unary operator that corresponds to a SPIR-V opcode.
// @param opcode SPIR-V opcode
// @param ast_unary_op output parameter; written only when a mapping exists
// @returns true if the opcode is a unary operation, false otherwise
bool GetUnaryOp(spv::Op opcode, core::UnaryOp* ast_unary_op) {
    // Signed-integer and floating point negation share one AST operator.
    if (opcode == spv::Op::OpSNegate || opcode == spv::Op::OpFNegate) {
        *ast_unary_op = core::UnaryOp::kNegation;
        return true;
    }
    // Logical (boolean) not.
    if (opcode == spv::Op::OpLogicalNot) {
        *ast_unary_op = core::UnaryOp::kNot;
        return true;
    }
    // Bitwise complement.
    if (opcode == spv::Op::OpNot) {
        *ast_unary_op = core::UnaryOp::kComplement;
        return true;
    }
    return false;
}
/// Finds the WGSL builtin function that directly implements a unary SPIR-V
/// opcode.
/// @param opcode SPIR-V opcode
/// @returns the WGSL builtin function name for the given opcode, or nullptr if
/// there is no direct translation.
const char* GetUnaryBuiltInFunctionName(spv::Op opcode) {
    const char* name = nullptr;
    switch (opcode) {
        case spv::Op::OpAny:
            name = "any";
            break;
        case spv::Op::OpAll:
            name = "all";
            break;
        case spv::Op::OpIsNan:
            name = "isNan";
            break;
        case spv::Op::OpIsInf:
            name = "isInf";
            break;
        case spv::Op::OpTranspose:
            name = "transpose";
            break;
        default:
            // No direct translation: leave the name null.
            break;
    }
    return name;
}
// Maps a SPIR-V opcode onto the AST binary operator that implements it.
// Integer and floating point variants of an operation share one AST operator.
// @param opcode SPIR-V opcode
// @returns the AST binary op for the given opcode, or std::nullopt when the
// opcode has no binary-operator translation
std::optional<core::BinaryOp> ConvertBinaryOp(spv::Op opcode) {
    using Op = spv::Op;
    using BinaryOp = core::BinaryOp;

    std::optional<BinaryOp> result;
    switch (opcode) {
        // Arithmetic.
        case Op::OpIAdd:
        case Op::OpFAdd:
            result = BinaryOp::kAdd;
            break;
        case Op::OpISub:
        case Op::OpFSub:
            result = BinaryOp::kSubtract;
            break;
        case Op::OpIMul:
        case Op::OpFMul:
        case Op::OpVectorTimesScalar:
        case Op::OpMatrixTimesScalar:
        case Op::OpVectorTimesMatrix:
        case Op::OpMatrixTimesVector:
        case Op::OpMatrixTimesMatrix:
            result = BinaryOp::kMultiply;
            break;
        case Op::OpUDiv:
        case Op::OpSDiv:
        case Op::OpFDiv:
            result = BinaryOp::kDivide;
            break;
        // It's not clear what OpSMod should map to.
        // https://bugs.chromium.org/p/tint/issues/detail?id=52
        case Op::OpUMod:
        case Op::OpSMod:
        case Op::OpSRem:
        case Op::OpFRem:
            result = BinaryOp::kModulo;
            break;

        // Equality.
        case Op::OpLogicalEqual:
        case Op::OpIEqual:
        case Op::OpFOrdEqual:
            result = BinaryOp::kEqual;
            break;
        case Op::OpLogicalNotEqual:
        case Op::OpINotEqual:
        case Op::OpFOrdNotEqual:
            result = BinaryOp::kNotEqual;
            break;

        // Bitwise and logical operations map to the same AST operators.
        case Op::OpBitwiseAnd:
        case Op::OpLogicalAnd:
            result = BinaryOp::kAnd;
            break;
        case Op::OpBitwiseOr:
        case Op::OpLogicalOr:
            result = BinaryOp::kOr;
            break;
        case Op::OpBitwiseXor:
            result = BinaryOp::kXor;
            break;

        // Ordered comparisons.
        case Op::OpUGreaterThan:
        case Op::OpSGreaterThan:
        case Op::OpFOrdGreaterThan:
            result = BinaryOp::kGreaterThan;
            break;
        case Op::OpUGreaterThanEqual:
        case Op::OpSGreaterThanEqual:
        case Op::OpFOrdGreaterThanEqual:
            result = BinaryOp::kGreaterThanEqual;
            break;
        case Op::OpULessThan:
        case Op::OpSLessThan:
        case Op::OpFOrdLessThan:
            result = BinaryOp::kLessThan;
            break;
        case Op::OpULessThanEqual:
        case Op::OpSLessThanEqual:
        case Op::OpFOrdLessThanEqual:
            result = BinaryOp::kLessThanEqual;
            break;

        default:
            // Not a binary operator: result stays empty.
            break;
    }
    return result;
}
// Handles the floating point *unordered* comparisons. Each such opcode is
// reported as the ordered comparison whose logical negation it equals, e.g.
// OpFUnordLessThan yields kGreaterThanEqual.
// @param opcode SPIR-V opcode
// @returns the binary comparison that is the negation of the SPIR-V opcode, or
// std::nullopt if the opcode is not an unordered float comparison
std::optional<core::BinaryOp> NegatedFloatCompare(spv::Op opcode) {
    using Op = spv::Op;
    using BinaryOp = core::BinaryOp;

    switch (opcode) {
        case Op::OpFUnordEqual:
            return BinaryOp::kNotEqual;
        case Op::OpFUnordNotEqual:
            return BinaryOp::kEqual;
        case Op::OpFUnordLessThan:
            return BinaryOp::kGreaterThanEqual;
        case Op::OpFUnordLessThanEqual:
            return BinaryOp::kGreaterThan;
        case Op::OpFUnordGreaterThan:
            return BinaryOp::kLessThanEqual;
        case Op::OpFUnordGreaterThanEqual:
            return BinaryOp::kLessThan;
        default:
            return std::nullopt;
    }
}
// Returns the WGSL standard library function for the given
// GLSL.std.450 extended instruction operation code. Unknown
// and invalid opcodes map to the empty string.
// @returns the WGSL standard function name, or an empty string.
std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
switch (ext_opcode) {
case GLSLstd450FAbs:
case GLSLstd450SAbs:
return "abs";
case GLSLstd450Acos:
return "acos";
case GLSLstd450Asin:
return "asin";
case GLSLstd450Atan:
return "atan";
case GLSLstd450Atan2:
return "atan2";
case GLSLstd450Ceil:
return "ceil";
case GLSLstd450UClamp:
case GLSLstd450SClamp:
case GLSLstd450NClamp:
case GLSLstd450FClamp: // FClamp is less prescriptive about NaN operands
return "clamp";
case GLSLstd450Cos:
return "cos";
case GLSLstd450Cosh:
return "cosh";
case GLSLstd450Cross:
return "cross";
case GLSLstd450Degrees:
return "degrees";
case GLSLstd450Determinant:
return "determinant";
case GLSLstd450Distance:
return "distance";
case GLSLstd450Exp:
return "exp";
case GLSLstd450Exp2:
return "exp2";
case GLSLstd450FaceForward:
return "faceForward";
case GLSLstd450FindILsb:
return "firstTrailingBit";
case GLSLstd450FindSMsb:
return "firstLeadingBit";
case GLSLstd450FindUMsb:
return "firstLeadingBit";
case GLSLstd450Floor:
return "floor";
case GLSLstd450Fma:
return "fma";
case GLSLstd450Fract:
return "fract";
case GLSLstd450InverseSqrt:
return "inverseSqrt";
case GLSLstd450Ldexp:
return "ldexp";
case GLSLstd450Length:
return "length";
case GLSLstd450Log:
return "log";
case GLSLstd450Log2:
return "log2";
case GLSLstd450NMax:
case GLSLstd450FMax: // FMax is less prescriptive about NaN operands
case GLSLstd450UMax:
case GLSLstd450SMax:
return "max";
case GLSLstd450NMin:
case GLSLstd450FMin: // FMin is less prescriptive about NaN operands
case GLSLstd450UMin:
case GLSLstd450SMin:
return "min";
case GLSLstd450FMix:
return "mix";
case GLSLstd450Normalize:
return "normalize";
case GLSLstd450PackSnorm4x8:
return "pack4x8snorm";
case GLSLstd450PackUnorm4x8:
return "pack4x8unorm";
case GLSLstd450PackSnorm2x16:
return "pack2x16snorm";
case GLSLstd450PackUnorm2x16:
return "pack2x16unorm";
case GLSLstd450PackHalf2x16:
return "pack2x16float";
case GLSLstd450Pow:
return "pow";
case GLSLstd450FSign:
case GLSLstd450SSign:
return "sign";
case GLSLstd450Radians:
return "radians";
case GLSLstd450Reflect:
return "reflect";
case GLSLstd450Refract:
return "refract";
case GLSLstd450Round:
case GLSLstd450RoundEven:
return "round";
case GLSLstd450Sin:
return "sin";
case GLSLstd450Sinh:
return "sinh";
case GLSLstd450SmoothStep:
return "smoothstep";
case GLSLstd450Sqrt:
return "sqrt";
case GLSLstd450Step:
return "step";
case GLSLstd450Tan:
return "tan";
case GLSLstd450Tanh:
return "tanh";
case GLSLstd450Trunc:
return "trunc";
case GLSLstd450UnpackSnorm4x8:
return "unpack4x8snorm";
case GLSLstd450UnpackUnorm4x8:
return "unpack4x8unorm";
case GLSLstd450UnpackSnorm2x16:
return "unpack2x16snorm";
case GLSLstd450UnpackUnorm2x16:
return "unpack2x16unorm";
case GLSLstd450UnpackHalf2x16:
return "unpack2x16float";
default:
// TODO(dneto) - The following are not implemented.
// They are grouped semantically, as in GLSL.std.450.h.
case GLSLstd450Asinh:
case GLSLstd450Acosh:
case GLSLstd450Atanh:
case GLSLstd450Modf:
case GLSLstd450ModfStruct:
case GLSLstd450IMix:
case GLSLstd450Frexp:
case GLSLstd450FrexpStruct:
case GLSLstd450PackDouble2x32:
case GLSLstd450UnpackDouble2x32:
case GLSLstd450InterpolateAtCentroid:
case GLSLstd450InterpolateAtSample:
case GLSLstd450InterpolateAtOffset:
break;
}
return "";
}
// Finds the WGSL builtin function that implements a SPIR-V core instruction.
// Both signed and unsigned bit-field extraction map to the same builtin.
// @param opcode SPIR-V opcode
// @returns the matching builtin, or wgsl::BuiltinFn::kNone if there is none
wgsl::BuiltinFn GetBuiltin(spv::Op opcode) {
    using Op = spv::Op;
    using Fn = wgsl::BuiltinFn;

    switch (opcode) {
        // Bit manipulation.
        case Op::OpBitCount:
            return Fn::kCountOneBits;
        case Op::OpBitFieldInsert:
            return Fn::kInsertBits;
        case Op::OpBitFieldSExtract:
        case Op::OpBitFieldUExtract:
            return Fn::kExtractBits;
        case Op::OpBitReverse:
            return Fn::kReverseBits;

        case Op::OpDot:
            return Fn::kDot;

        // Derivatives.
        case Op::OpDPdx:
            return Fn::kDpdx;
        case Op::OpDPdy:
            return Fn::kDpdy;
        case Op::OpFwidth:
            return Fn::kFwidth;
        case Op::OpDPdxFine:
            return Fn::kDpdxFine;
        case Op::OpDPdyFine:
            return Fn::kDpdyFine;
        case Op::OpFwidthFine:
            return Fn::kFwidthFine;
        case Op::OpDPdxCoarse:
            return Fn::kDpdxCoarse;
        case Op::OpDPdyCoarse:
            return Fn::kDpdyCoarse;
        case Op::OpFwidthCoarse:
            return Fn::kFwidthCoarse;

        default:
            return Fn::kNone;
    }
}
// Classifies image access instructions that take an OpSampledImage value as
// their first input operand.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is such an image access instruction
bool IsSampledImageAccess(spv::Op opcode) {
    using Op = spv::Op;
    switch (opcode) {
        case Op::OpImageSampleImplicitLod:
        case Op::OpImageSampleExplicitLod:
        case Op::OpImageSampleDrefImplicitLod:
        case Op::OpImageSampleDrefExplicitLod:
        // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
        case Op::OpImageSampleProjImplicitLod:
        case Op::OpImageSampleProjExplicitLod:
        case Op::OpImageSampleProjDrefImplicitLod:
        case Op::OpImageSampleProjDrefExplicitLod:
        case Op::OpImageGather:
        case Op::OpImageDrefGather:
        case Op::OpImageQueryLod:
            return true;
        default:
            return false;
    }
}
// Classifies atomic instructions, including the EXT floating point atomics.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an atomic operation.
bool IsAtomicOp(spv::Op opcode) {
    using Op = spv::Op;
    switch (opcode) {
        // Load / store / exchange.
        case Op::OpAtomicLoad:
        case Op::OpAtomicStore:
        case Op::OpAtomicExchange:
        case Op::OpAtomicCompareExchange:
        case Op::OpAtomicCompareExchangeWeak:
        // Integer arithmetic.
        case Op::OpAtomicIIncrement:
        case Op::OpAtomicIDecrement:
        case Op::OpAtomicIAdd:
        case Op::OpAtomicISub:
        case Op::OpAtomicSMin:
        case Op::OpAtomicUMin:
        case Op::OpAtomicSMax:
        case Op::OpAtomicUMax:
        // Bitwise.
        case Op::OpAtomicAnd:
        case Op::OpAtomicOr:
        case Op::OpAtomicXor:
        // Flags.
        case Op::OpAtomicFlagTestAndSet:
        case Op::OpAtomicFlagClear:
        // Floating point extensions.
        case Op::OpAtomicFMinEXT:
        case Op::OpAtomicFMaxEXT:
        case Op::OpAtomicFAddEXT:
            return true;
        default:
            return false;
    }
}
// Classifies image sampling, gather, and gather-compare instructions.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image sampling, gather,
// or gather-compare operation.
bool IsImageSamplingOrGatherOrDrefGather(spv::Op opcode) {
    using Op = spv::Op;
    switch (opcode) {
        case Op::OpImageSampleImplicitLod:
        case Op::OpImageSampleExplicitLod:
        case Op::OpImageSampleDrefImplicitLod:
        case Op::OpImageSampleDrefExplicitLod:
        // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
        case Op::OpImageSampleProjImplicitLod:
        case Op::OpImageSampleProjExplicitLod:
        case Op::OpImageSampleProjDrefImplicitLod:
        case Op::OpImageSampleProjDrefExplicitLod:
        case Op::OpImageGather:
        case Op::OpImageDrefGather:
            return true;
        default:
            return false;
    }
}
// Classifies image access instructions that take an OpImage value as their
// first input operand: read, write, and fetch.
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is such an image access instruction
bool IsRawImageAccess(spv::Op opcode) {
    return opcode == spv::Op::OpImageRead ||   //
           opcode == spv::Op::OpImageWrite ||  //
           opcode == spv::Op::OpImageFetch;
}
// Classifies image query instructions (size, levels, samples, lod).
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image query instruction
bool IsImageQuery(spv::Op opcode) {
    return opcode == spv::Op::OpImageQuerySize ||     //
           opcode == spv::Op::OpImageQuerySizeLod ||  //
           opcode == spv::Op::OpImageQueryLevels ||   //
           opcode == spv::Op::OpImageQuerySamples ||  //
           opcode == spv::Op::OpImageQueryLod;
}
// Reads the merge block named by the block's OpSelectionMerge or OpLoopMerge
// instruction.
// @param bb the basic block to inspect
// @returns the merge block ID for the given basic block, or 0 if there is none.
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
    if (auto* merge_inst = bb.GetMergeInst()) {
        // The merge block ID is the first input operand of either merge
        // instruction.
        return merge_inst->GetSingleWordInOperand(0);
    }
    return 0;
}
// Reads the continue target named by the block's OpLoopMerge instruction.
// @param bb the basic block to inspect
// @returns the continue target ID for the given basic block, or 0 if there
// is none.
uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
    if (auto* loop_merge_inst = bb.GetLoopMergeInst()) {
        // The continue target is the second input operand of OpLoopMerge.
        return loop_merge_inst->GetSingleWordInOperand(1);
    }
    return 0;
}
// A structured traverser produces the reverse structured post-order of the
// CFG of a function. The blocks traversed are the transitive closure (minimum
// fixed point) of:
// - the entry block
// - a block reached by a branch from another block in the set
// - a block mentioned as a merge block or continue target for a block in the
// set
class StructuredTraverser {
public:
explicit StructuredTraverser(const spvtools::opt::Function& function) : function_(function) {
for (auto& block : function_) {
id_to_block_[block.id()] = &block;
}
}
// Returns the reverse postorder traversal of the CFG, where:
// - a merge block always follows its associated constructs
// - a continue target always follows the associated loop construct, if any
// @returns the IDs of blocks in reverse structured post order
std::vector<uint32_t> ReverseStructuredPostOrder() {
visit_order_.Clear();
visited_.clear();
VisitBackward(function_.entry()->id());
std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
return order;
}
private:
// Executes a depth first search of the CFG, where right after we visit a
// header, we will visit its merge block, then its continue target (if any).
// Also records the post order ordering.
void VisitBackward(uint32_t id) {
if (id == 0) {
return;
}
if (visited_.count(id)) {
return;
}
visited_.insert(id);
const spvtools::opt::BasicBlock* bb = id_to_block_[id]; // non-null for valid modules
VisitBackward(MergeFor(*bb));
VisitBackward(ContinueTargetFor(*bb));
// Visit successors. We will naturally skip the continue target and merge
// blocks.
auto* terminator = bb->terminator();
const auto opcode = terminator->opcode();
if (opcode == spv::Op::OpBranchConditional) {
// Visit the false branch, then the true branch, to make them come
// out in the natural order for an "if".
VisitBackward(terminator->GetSingleWordInOperand(2));
VisitBackward(terminator->GetSingleWordInOperand(1));
} else if (opcode == spv::Op::OpBranch) {
VisitBackward(terminator->GetSingleWordInOperand(0));
} else if (opcode == spv::Op::OpSwitch) {
// TODO(dneto): Consider visiting the labels in literal-value order.
tint::Vector<uint32_t, 32> successors;
bb->ForEachSuccessorLabel(
[&successors](const uint32_t succ_id) { successors.Push(succ_id); });
for (auto succ_id : successors) {
VisitBackward(succ_id);
}
}
visit_order_.Push(id);
}
const spvtools::opt::Function& function_;
std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
tint::Vector<uint32_t, 32> visit_order_;
std::unordered_set<uint32_t> visited_;
};
/// A StatementBuilder for ast::SwitchStatement.
/// Cases are accumulated in `cases` in the reverse of their WGSL presentation
/// order; Build() puts them back in order.
/// @see StatementBuilder
struct SwitchStatementBuilder final : public Castable<SwitchStatementBuilder, StatementBuilder> {
    /// Constructor
    /// @param cond the switch statement condition
    explicit SwitchStatementBuilder(const ast::Expression* cond) : condition(cond) {}

    /// Builds the switch statement from the condition and the collected cases.
    /// The `cases` member itself is left untouched: a copy is reordered.
    /// @param builder the program builder
    /// @returns the built ast::SwitchStatement
    const ast::SwitchStatement* Build(ProgramBuilder* builder) const override {
        // We've listed cases in reverse order in the switch statement.
        // Reorder a copy of them to match the presentation order in WGSL.
        auto ordered_cases = cases;
        std::reverse(ordered_cases.begin(), ordered_cases.end());
        const auto* switch_stmt = builder->Switch(Source{}, condition, std::move(ordered_cases));
        return switch_stmt;
    }

    /// Switch statement condition
    const ast::Expression* const condition;
    /// Switch statement cases, in reverse presentation order
    tint::Vector<ast::CaseStatement*, 4> cases;
};
/// A StatementBuilder for ast::IfStatement.
/// The condition is fixed at construction; the body and the else statement
/// are filled in later, before Build() is called.
/// @see StatementBuilder
struct IfStatementBuilder final : public Castable<IfStatementBuilder, StatementBuilder> {
    /// Constructor
    /// @param c the if-statement condition
    explicit IfStatementBuilder(const ast::Expression* c) : cond(c) {}

    /// Builds the if-statement from the condition, body and else statement
    /// currently stored in this builder, with no attributes.
    /// @param builder the program builder
    /// @returns the built ast::IfStatement
    const ast::IfStatement* Build(ProgramBuilder* builder) const override {
        const auto* if_stmt =
            builder->create<ast::IfStatement>(Source{}, cond, body, else_stmt, tint::Empty);
        return if_stmt;
    }

    /// If-statement condition
    const ast::Expression* const cond;
    /// If-statement block body
    const ast::BlockStatement* body = nullptr;
    /// Optional if-statement else statement
    const ast::Statement* else_stmt = nullptr;
};
/// A StatementBuilder for ast::LoopStatement.
/// Both the body and the continuing block start out null and are filled in
/// before Build() is called.
/// @see StatementBuilder
struct LoopStatementBuilder final : public Castable<LoopStatementBuilder, StatementBuilder> {
    /// Builds the loop statement from the body and continuing block currently
    /// stored in this builder, with no attributes.
    /// @param builder the program builder
    /// @returns the built ast::LoopStatement
    ast::LoopStatement* Build(ProgramBuilder* builder) const override {
        auto* loop_stmt =
            builder->create<ast::LoopStatement>(Source{}, body, continuing, tint::Empty);
        return loop_stmt;
    }

    /// Loop-statement block body
    const ast::BlockStatement* body = nullptr;
    /// Loop-statement continuing body
    /// @note the mutable keyword here is required as all non-StatementBuilders
    /// `ast::Node`s are immutable and are referenced with `const` pointers.
    /// StatementBuilders however exist to provide mutable state while the
    /// FunctionEmitter is building the function. All StatementBuilders are
    /// replaced with immutable AST nodes when Finalize() is called.
    mutable const ast::BlockStatement* continuing = nullptr;
};
} // namespace
/// Constructs the record for a basic block, storing a pointer to the block and
/// caching its ID.
/// @param bb the SPIR-V basic block described by this record
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) : basic_block(&bb), id(bb.id()) {}
/// Destructor
BlockInfo::~BlockInfo() = default;
/// Constructs a DefInfo whose `local` member is populated from a block position.
/// @param the_index the value stored in `index`
/// @param def_inst the defining instruction, stored in `inst`
/// @param the_block_pos the block position used to construct `local`
DefInfo::DefInfo(size_t the_index,
const spvtools::opt::Instruction& def_inst,
uint32_t the_block_pos)
: index(the_index), inst(def_inst), local(DefInfo::Local(the_block_pos)) {}
/// Constructs a DefInfo without supplying a block position; `local` is left to
/// its default initialization.
/// @param the_index the value stored in `index`
/// @param def_inst the defining instruction, stored in `inst`
DefInfo::DefInfo(size_t the_index, const spvtools::opt::Instruction& def_inst)
: index(the_index), inst(def_inst) {}
/// Destructor
DefInfo::~DefInfo() = default;
/// Constructs the local-definition data for the given block position.
/// @param the_block_pos the value stored in `block_pos`
DefInfo::Local::Local(uint32_t the_block_pos) : block_pos(the_block_pos) {}
/// Copy constructor
/// @param other the Local to copy
DefInfo::Local::Local(const Local& other) = default;
/// Destructor
DefInfo::Local::~Local() = default;
/// Clone implementation for StatementBuilder. It performs no cloning: the
/// clone context is ignored.
/// @returns nullptr, always
ast::Node* StatementBuilder::Clone(ast::CloneContext&) const {
return nullptr;
}
/// Constructs an emitter for a function, optionally as an entry point.
/// Caches the parser's type manager, program builder, IR context (and the
/// def-use, constant and type managers derived from it), fail stream and namer.
/// Both sample mask IDs start at 0. The statement stack is seeded with one
/// block that has no construct, end ID 0 and no completion action.
/// @param pi the parser; dereferenced without a null check
/// @param function the SPIR-V function to emit
/// @param ep_info entry point information, or nullptr
FunctionEmitter::FunctionEmitter(ASTParser* pi,
const spvtools::opt::Function& function,
const EntryPointInfo* ep_info)
: parser_impl_(*pi),
ty_(pi->type_manager()),
builder_(pi->builder()),
ir_context_(*(pi->ir_context())),
def_use_mgr_(ir_context_.get_def_use_mgr()),
constant_mgr_(ir_context_.get_constant_mgr()),
type_mgr_(ir_context_.get_type_mgr()),
fail_stream_(pi->fail_stream()),
namer_(pi->namer()),
function_(function),
sample_mask_in_id(0u),
sample_mask_out_id(0u),
ep_info_(ep_info) {
PushNewStatementBlock(nullptr, 0, nullptr);
}
/// Constructs an emitter with no entry point information, by delegating to the
/// three-argument constructor with a null `ep_info`.
/// @param pi the parser
/// @param function the SPIR-V function to emit
FunctionEmitter::FunctionEmitter(ASTParser* pi, const spvtools::opt::Function& function)
: FunctionEmitter(pi, function, nullptr) {}
/// Move constructor.
/// Takes over the references and scalar state held by `other`, re-derives the
/// def-use, constant and type managers from the shared IR context, and starts
/// with a fresh statement stack holding a single block (no construct, end ID
/// 0, no completion action).
/// @param other the emitter to move from; its statement stack is cleared
FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
    : parser_impl_(other.parser_impl_),
      ty_(other.ty_),
      builder_(other.builder_),
      ir_context_(other.ir_context_),
      def_use_mgr_(ir_context_.get_def_use_mgr()),
      constant_mgr_(ir_context_.get_constant_mgr()),
      type_mgr_(ir_context_.get_type_mgr()),
      fail_stream_(other.fail_stream_),
      namer_(other.namer_),
      function_(other.function_),
      // Each sample mask ID must come from the same-named member of `other`;
      // cross-wiring them would swap the input and output IDs on every move.
      sample_mask_in_id(other.sample_mask_in_id),
      sample_mask_out_id(other.sample_mask_out_id),
      ep_info_(other.ep_info_) {
    other.statements_stack_.Clear();
    PushNewStatementBlock(nullptr, 0, nullptr);
}
/// Destructor
FunctionEmitter::~FunctionEmitter() = default;
/// Constructs a statement block.
/// @param construct the construct stored with this block; may be null
/// @param end_id the end ID stored with this block
/// @param completion_action the action invoked on the statement list by
/// Finalize(); may be null
FunctionEmitter::StatementBlock::StatementBlock(const Construct* construct,
uint32_t end_id,
FunctionEmitter::CompletionAction completion_action)
: construct_(construct), end_id_(end_id), completion_action_(completion_action) {}
/// Move constructor
/// @param other the statement block to move from
FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other) = default;
/// Destructor
FunctionEmitter::StatementBlock::~StatementBlock() = default;
/// Finalizes the block: every StatementBuilder in the statement list is
/// replaced, in place, by the statement it builds; then the completion action
/// (if any) is run on the resulting list. Must be called at most once.
/// @param pb the program builder handed to each StatementBuilder
void FunctionEmitter::StatementBlock::Finalize(ProgramBuilder* pb) {
    TINT_ASSERT(!finalized_ /* Finalize() must only be called once */);

    for (auto& stmt : statements_) {
        // Entries that are not StatementBuilders are already final.
        if (auto* pending = stmt->As<StatementBuilder>()) {
            stmt = pending->Build(pb);
        }
    }

    // The completion action observes the list only after all builders have
    // been replaced.
    if (completion_action_ != nullptr) {
        completion_action_(statements_);
    }

    finalized_ = true;
}
/// Appends a statement to the end of this block's statement list.
/// Must not be called once the block has been finalized.
/// @param statement the statement to append
void FunctionEmitter::StatementBlock::Add(const ast::Statement* statement) {
TINT_ASSERT(!finalized_ /* Add() must not be called after Finalize() */);
statements_.Push(statement);
}
/// Pushes a new, empty statement block onto the statement stack.
/// @param construct the construct to associate with the new block; may be null
/// @param end_id the end ID to associate with the new block
/// @param action the completion action for the new block; may be null
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
uint32_t end_id,
CompletionAction action) {
statements_stack_.Push(StatementBlock{construct, end_id, action});
}
/// Guards subsequent control flow by a guard variable: emits an if-statement
/// whose condition is the named variable, and pushes a new statement block
/// that collects the statements of its then-clause. The new block shares the
/// construct of the block that was on top of the stack.
/// @param guard_name the name of the guard variable; must not be empty
/// @param end_id the end ID for the newly pushed statement block
void FunctionEmitter::PushGuard(const std::string& guard_name, uint32_t end_id) {
    TINT_ASSERT(!statements_stack_.IsEmpty());
    TINT_ASSERT(!guard_name.empty());

    const auto& enclosing = statements_stack_.Back();
    // The if-statement builder is added to the enclosing block now; its body
    // is attached once the then-clause block is finalized.
    auto* guard_cond = builder_.Expr(Source{}, guard_name);
    auto* if_builder = AddStatementBuilder<IfStatementBuilder>(guard_cond);

    PushNewStatementBlock(enclosing.GetConstruct(), end_id, [=](const StatementList& guarded) {
        if_builder->body = create<ast::BlockStatement>(Source{}, guarded, tint::Empty);
    });
}
/// Emits an if-statement whose condition is the constant true, and pushes a
/// new statement block that collects the statements of its then-clause. The
/// new block shares the construct of the block that was on top of the stack.
/// @param end_id the end ID for the newly pushed statement block
void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
    TINT_ASSERT(!statements_stack_.IsEmpty());

    const auto& enclosing = statements_stack_.Back();
    // The if-statement builder is added to the enclosing block now; its body
    // is attached once the then-clause block is finalized.
    auto* always_true = MakeTrue(Source{});
    auto* if_builder = AddStatementBuilder<IfStatementBuilder>(always_true);

    PushNewStatementBlock(enclosing.GetConstruct(), end_id, [=](const StatementList& guarded) {
        if_builder->body = create<ast::BlockStatement>(Source{}, guarded, tint::Empty);
    });
}
/// Finalizes the bottom-most statement block on the stack and returns a copy
/// of its statements. Because Finalize() asserts that it runs only once per
/// block, this is expected to be called at most once for a given bottom block.
/// @returns the statements of the bottom-most statement block
FunctionEmitter::StatementList FunctionEmitter::ast_body() {
TINT_ASSERT(!statements_stack_.IsEmpty());
auto& entry = statements_stack_[0];
entry.Finalize(&builder_);
return entry.GetStatements();
}
/// Appends a statement to the statement block on top of the stack.
/// A null statement is accepted and simply not appended.
/// @param statement the statement to append, or nullptr
/// @returns the `statement` argument, unchanged
const ast::Statement* FunctionEmitter::AddStatement(const ast::Statement* statement) {
    TINT_ASSERT(!statements_stack_.IsEmpty());
    if (statement == nullptr) {
        return statement;
    }
    auto& current_block = statements_stack_.Back();
    current_block.Add(statement);
    return statement;
}
/// Returns the most recently added statement of the statement block on top of
/// the stack. Both the stack and that block's statement list must be non-empty.
/// @returns the last statement in the top statement block
const ast::Statement* FunctionEmitter::LastStatement() {
TINT_ASSERT(!statements_stack_.IsEmpty());
auto& statement_list = statements_stack_.Back().GetStatements();
TINT_ASSERT(!statement_list.IsEmpty());
return statement_list.Back();
}
bool FunctionEmitter::Emit() {
if (failed()) {
return false;
}
// We only care about functions with bodies.
if (function_.cbegin() == function_.cend()) {
return true;
}
// The function declaration, corresponding to how it's written in SPIR-V,
// and without regard to whether it's an entry point.
FunctionDeclaration decl;
if (!ParseFunctionDeclaration(&decl)) {
return false;
}
bool make_body_function = true;
if (ep_info_) {
TINT_ASSERT(!ep_info_->inner_name.empty());
if (ep_info_->owns_inner_implementation) {
// This is an entry point, and we want to emit it as a wrapper around
// an implementation function.
decl.name = ep_info_->inner_name;
} else {
// This is a second entry point that shares an inner implementation
// function.
make_body_function = false;
}
}
if (make_body_function) {
auto* body = MakeFunctionBody();
if (!body) {
return false;
}
builder_.Func(decl.source, decl.name, std::move(decl.params),
decl.return_type->Build(builder_), body, std::move(decl.attributes.list));
}
if (ep_info_ && !ep_info_->inner_name.empty()) {
return EmitEntryPointAsWrapper();
}
return success();
}
/// Emits the function body and wraps the resulting statements in a block
/// statement. On entry the statement stack must hold exactly one block; on
/// success it again holds exactly one (new, empty) block.
/// @returns the body block statement, or nullptr on failure
const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
    TINT_ASSERT(statements_stack_.Length() == 1);

    if (!EmitBody()) {
        return nullptr;
    }

    // Set the body of the AST function node.
    const auto stack_depth = statements_stack_.Length();
    if (stack_depth != 1) {
        Fail() << "internal error: statement-list stack should have 1 "
                  "element but has "
               << stack_depth;
        return nullptr;
    }

    auto& root_block = statements_stack_[0];
    root_block.Finalize(&builder_);
    auto* body = create<ast::BlockStatement>(Source{}, root_block.GetStatements(), tint::Empty);

    // Maintain the invariant by repopulating the one and only element. The new
    // block is associated with the first construct.
    statements_stack_.Clear();
    PushNewStatementBlock(constructs_[0].get(), 0, nullptr);

    return body;
}
// Flattens one pipeline input into entry point parameters.
//
// Matrices, arrays and structures are expanded recursively; each leaf produces
// one parameter pushed onto |params| and one assignment pushed onto
// |statements| that copies the parameter into the matching part of the variable
// named |var_name|.
//
// var_name:          name of the variable that receives the input.
// var_type:          type of that variable; used to rebuild the access chain.
// index_prefix:      indices leading from the variable to the part described
//                    by |tip_type|. Taken by value so each recursion level
//                    extends its own copy.
// tip_type:          type of the part currently being flattened.
// forced_param_type: parameter type used instead of |tip_type| when |attrs|
//                    carries a builtin attribute.
// attrs:             attributes for the emitted parameter; any location
//                    attribute is incremented after each leaf.
// Returns false if an error was logged.
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
                                        const Type* var_type,
                                        tint::Vector<int, 8> index_prefix,
                                        const Type* tip_type,
                                        const Type* forced_param_type,
                                        Attributes& attrs,
                                        ParameterList& params,
                                        StatementList& statements) {
    // TODO(dneto): Handle structs where the locations are annotated on members.
    tip_type = tip_type->UnwrapAlias();
    if (auto* ref_type = tip_type->As<Reference>()) {
        tip_type = ref_type->type;
    }
    // Recursively flatten matrices, arrays, and structures.
    return Switch(
        tip_type,
        [&](const Matrix* matrix_type) -> bool {
            // One parameter (or group of parameters) per column vector.
            index_prefix.Push(0);
            const auto num_columns = static_cast<int>(matrix_type->columns);
            const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
            for (int col = 0; col < num_columns; col++) {
                index_prefix.Back() = col;
                if (!EmitPipelineInput(var_name, var_type, index_prefix, vec_ty, forced_param_type,
                                       attrs, params, statements)) {
                    return false;
                }
            }
            return success();
        },
        [&](const Array* array_type) -> bool {
            if (array_type->size == 0) {
                return Fail() << "runtime-size array not allowed on pipeline IO";
            }
            index_prefix.Push(0);
            const Type* elem_ty = array_type->type;
            for (int i = 0; i < static_cast<int>(array_type->size); i++) {
                index_prefix.Back() = i;
                if (!EmitPipelineInput(var_name, var_type, index_prefix, elem_ty, forced_param_type,
                                       attrs, params, statements)) {
                    return false;
                }
            }
            return success();
        },
        [&](const Struct* struct_type) -> bool {
            const auto& members = struct_type->members;
            index_prefix.Push(0);
            for (size_t i = 0; i < members.size(); ++i) {
                index_prefix.Back() = static_cast<int>(i);
                // Each member starts from a copy of the enclosing attributes,
                // then layers its own pipeline decorations on top.
                Attributes member_attrs(attrs);
                if (!parser_impl_.ConvertPipelineDecorations(
                        struct_type,
                        parser_impl_.GetMemberPipelineDecorations(*struct_type,
                                                                  static_cast<int>(i)),
                        member_attrs)) {
                    return false;
                }
                if (!EmitPipelineInput(var_name, var_type, index_prefix, members[i],
                                       forced_param_type, member_attrs, params, statements)) {
                    return false;
                }
                // Copy the location as updated by nested expansion of the member.
                parser_impl_.SetLocation(attrs, member_attrs.Get<ast::LocationAttribute>());
            }
            return success();
        },
        [&](Default) {
            const bool is_builtin = attrs.Has<ast::BuiltinAttribute>();
            const Type* param_type = is_builtin ? forced_param_type : tip_type;
            const auto param_name = namer_.MakeDerivedName(var_name + "_param");
            // Create the parameter.
            // TODO(dneto): Note: If the parameter has non-location decorations, then those
            // decoration AST nodes will be reused between multiple elements of a matrix, array, or
            // structure.  Normally that's disallowed but currently the SPIR-V reader will make
            // duplicates when the entire AST is cloned at the top level of the SPIR-V reader flow.
            // Consider rewriting this to avoid this node-sharing.
            params.Push(builder_.Param(param_name, param_type->Build(builder_), attrs.list));

            // Add a body statement to copy the parameter to the corresponding
            // private variable.
            const ast::Expression* param_value = builder_.Expr(param_name);
            const ast::Expression* store_dest = builder_.Expr(var_name);

            // Index into the LHS as needed.
            auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
            for (auto index : index_prefix) {
                Switch(
                    current_type,
                    [&](const Matrix* matrix_type) {
                        store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
                        current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
                    },
                    [&](const Array* array_type) {
                        store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
                        current_type = array_type->type->UnwrapAlias();
                    },
                    [&](const Struct* struct_type) {
                        store_dest = builder_.MemberAccessor(
                            store_dest, parser_impl_.GetMemberName(*struct_type, index));
                        // Unwrap aliases, as the array case does. Otherwise an
                        // alias-typed aggregate member matches none of these
                        // cases on the next index and that accessor is
                        // silently skipped.
                        current_type =
                            struct_type->members[static_cast<size_t>(index)]->UnwrapAlias();
                    });
            }

            if (is_builtin && (tip_type != forced_param_type)) {
                // The parameter will have the WGSL type, but we need bitcast to the variable store
                // type.
                param_value = builder_.Bitcast(tip_type->Build(builder_), param_value);
            }

            statements.Push(builder_.Assign(store_dest, param_value));

            // Increment the location attribute, in case more parameters will follow.
            IncrementLocation(attrs);

            return success();
        });
}
// Replaces every location attribute in |attributes| with a new location
// attribute whose index is one higher. Other attributes are left untouched.
void FunctionEmitter::IncrementLocation(Attributes& attributes) {
    for (auto*& entry : attributes.list) {
        auto* location = entry->As<ast::LocationAttribute>();
        if (!location) {
            continue;
        }
        // The superseded attribute doesn't leak: it stays in the builder's AST
        // node list.
        const auto* literal = location->expr->As<ast::IntLiteralExpression>();
        entry = builder_.Location(location->source, AInt(literal->value + 1));
    }
}
// Flattens one pipeline output into members of the entry point's return
// structure.
//
// Matrices, arrays and structures are expanded recursively; each leaf produces
// one structure member pushed onto |return_members| and one expression pushed
// onto |return_exprs| that reads the matching part of the variable named
// |var_name|.
//
// var_name:           name of the variable holding the output value.
// var_type:           type of that variable; used to rebuild the access chain.
// index_prefix:       indices leading from the variable to the part described
//                     by |tip_type|. Taken by value so each recursion level
//                     extends its own copy.
// tip_type:           type of the part currently being flattened.
// forced_member_type: member type used instead of |tip_type| when |attrs|
//                     carries a builtin attribute.
// attrs:              attributes for the emitted member; any location
//                     attribute is incremented after each leaf.
// Returns false if an error was logged.
bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
                                         const Type* var_type,
                                         tint::Vector<int, 8> index_prefix,
                                         const Type* tip_type,
                                         const Type* forced_member_type,
                                         Attributes& attrs,
                                         StructMemberList& return_members,
                                         ExpressionList& return_exprs) {
    tip_type = tip_type->UnwrapAlias();
    if (auto* ref_type = tip_type->As<Reference>()) {
        tip_type = ref_type->type;
    }
    // Recursively flatten matrices, arrays, and structures.
    return Switch(
        tip_type,
        [&](const Matrix* matrix_type) {
            // One member (or group of members) per column vector.
            index_prefix.Push(0);
            const auto num_columns = static_cast<int>(matrix_type->columns);
            const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
            for (int col = 0; col < num_columns; col++) {
                index_prefix.Back() = col;
                if (!EmitPipelineOutput(var_name, var_type, index_prefix, vec_ty,
                                        forced_member_type, attrs, return_members, return_exprs)) {
                    return false;
                }
            }
            return success();
        },
        [&](const Array* array_type) -> bool {
            if (array_type->size == 0) {
                return Fail() << "runtime-size array not allowed on pipeline IO";
            }
            index_prefix.Push(0);
            const Type* elem_ty = array_type->type;
            for (int i = 0; i < static_cast<int>(array_type->size); i++) {
                index_prefix.Back() = i;
                if (!EmitPipelineOutput(var_name, var_type, index_prefix, elem_ty,
                                        forced_member_type, attrs, return_members, return_exprs)) {
                    return false;
                }
            }
            return success();
        },
        [&](const Struct* struct_type) -> bool {
            const auto& members = struct_type->members;
            index_prefix.Push(0);
            for (int i = 0; i < static_cast<int>(members.size()); ++i) {
                index_prefix.Back() = i;
                // Each member starts from a copy of the enclosing attributes,
                // then layers its own pipeline decorations on top.
                Attributes member_attrs(attrs);
                if (!parser_impl_.ConvertPipelineDecorations(
                        struct_type, parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
                        member_attrs)) {
                    return false;
                }
                if (!EmitPipelineOutput(var_name, var_type, index_prefix,
                                        members[static_cast<size_t>(i)], forced_member_type,
                                        member_attrs, return_members, return_exprs)) {
                    return false;
                }
                // Copy the location as updated by nested expansion of the member.
                parser_impl_.SetLocation(attrs, member_attrs.Get<ast::LocationAttribute>());
            }
            return success();
        },
        [&](Default) {
            const bool is_builtin = attrs.Has<ast::BuiltinAttribute>();
            const Type* member_type = is_builtin ? forced_member_type : tip_type;
            // Derive the member name directly from the variable name.  They can't
            // collide.
            const auto member_name = namer_.MakeDerivedName(var_name);
            // Create the member.
            // TODO(dneto): Note: If the parameter has non-location decorations, then those
            // decoration AST nodes  will be reused between multiple elements of a matrix, array, or
            // structure.  Normally that's disallowed but currently the SPIR-V reader will make
            // duplicates when the entire AST is cloned at the top level of the SPIR-V reader flow.
            // Consider rewriting this to avoid this node-sharing.
            return_members.Push(
                builder_.Member(member_name, member_type->Build(builder_), attrs.list));

            // Create an expression to evaluate the part of the variable indexed by
            // the index_prefix.
            const ast::Expression* load_source = builder_.Expr(var_name);

            // Index into the variable as needed to pick out the flattened member.
            auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
            for (auto index : index_prefix) {
                Switch(
                    current_type,
                    [&](const Matrix* matrix_type) {
                        load_source = builder_.IndexAccessor(load_source, i32(index));
                        current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
                    },
                    [&](const Array* array_type) {
                        load_source = builder_.IndexAccessor(load_source, i32(index));
                        current_type = array_type->type->UnwrapAlias();
                    },
                    [&](const Struct* struct_type) {
                        load_source = builder_.MemberAccessor(
                            load_source, parser_impl_.GetMemberName(*struct_type, index));
                        // Unwrap aliases, as the array case does. Otherwise an
                        // alias-typed aggregate member matches none of these
                        // cases on the next index and that accessor is
                        // silently skipped.
                        current_type =
                            struct_type->members[static_cast<size_t>(index)]->UnwrapAlias();
                    });
            }

            if (is_builtin && (tip_type != forced_member_type)) {
                // The member will have the WGSL type, but we need bitcast to
                // the variable store type.
                load_source = builder_.Bitcast(forced_member_type->Build(builder_), load_source);
            }
            return_exprs.Push(load_source);

            // Increment the location attribute, in case more parameters will follow.
            IncrementLocation(attrs);

            return success();
        });
}
// Emits the entry point named ep_info_->name as a wrapper function:
//  1. Each pipeline input becomes one or more parameters, whose values are
//     copied into the corresponding variables.
//  2. The inner function ep_info_->inner_name is called with no arguments.
//  3. Pipeline outputs are gathered into a newly declared structure, which is
//     returned. The return type is void when there are no outputs, or when no
//     structure members end up being emitted.
// Returns false if an error was logged while converting decorations or while
// flattening an input or output; otherwise returns true.
bool FunctionEmitter::EmitEntryPointAsWrapper() {
Source source;
// The statements in the body.
tint::Vector<const ast::Statement*, 8> stmts;
FunctionDeclaration decl;
decl.source = source;
decl.name = ep_info_->name;
ast::Type return_type; // Populated below.
// Pipeline inputs become parameters to the wrapper function, and
// their values are saved into the corresponding private variables that
// have already been created.
for (uint32_t var_id : ep_info_->inputs) {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(var != nullptr);
TINT_ASSERT(opcode(var) == spv::Op::OpVariable);
auto* store_type = GetVariableStoreType(*var);
// Starts out as the store type; passed by address so the decoration
// conversion below may replace it.
auto* forced_param_type = store_type;
Attributes param_attrs;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_param_type, param_attrs,
true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
// We don't have to handle initializers because in Vulkan SPIR-V, Input
// variables must not have them.
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (param_attrs.flags.Contains(Attributes::Flags::kHasBuiltinSampleMask)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
// Use the first element only.
auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(sample_mask_array_type);
ok = EmitPipelineInput(var_name, store_type, {0}, sample_mask_array_type->type,
forced_param_type, param_attrs, decl.params, stmts);
} else {
// The normal path.
ok = EmitPipelineInput(var_name, store_type, {}, store_type, forced_param_type,
param_attrs, decl.params, stmts);
}
if (!ok) {
return false;
}
}
// Call the inner function. It has no parameters.
stmts.Push(builder_.CallStmt(source, builder_.Call(source, ep_info_->inner_name)));
// Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.IsEmpty()) {
// There is nothing to return.
return_type = ty_.Void()->Build(builder_);
} else {
// Pipeline outputs are converted to a structure that is written
// to just before returning.
const auto return_struct_name = namer_.MakeDerivedName(ep_info_->name + "_out");
const auto return_struct_sym = builder_.Symbols().Register(return_struct_name);
// Define the structure.
StructMemberList return_members;
ExpressionList return_exprs;
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
for (uint32_t var_id : ep_info_->outputs) {
if (var_id == builtin_position_info.per_vertex_var_id) {
// The SPIR-V gl_PerVertex variable has already been remapped to
// a gl_Position variable. Substitute the type.
const Type* param_type = ty_.Vector(ty_.F32(), 4);
const auto var_name = namer_.GetName(var_id);
return_members.Push(
builder_.Member(var_name, param_type->Build(builder_),
tint::Vector{
builder_.Builtin(source, core::BuiltinValue::kPosition),
}));
return_exprs.Push(builder_.Expr(var_name));
} else {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(var != nullptr);
TINT_ASSERT(opcode(var) == spv::Op::OpVariable);
const Type* store_type = GetVariableStoreType(*var);
// Starts out as the store type; passed by address so the decoration
// conversion below may replace it.
const Type* forced_member_type = store_type;
Attributes out_attrs;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_member_type,
out_attrs, true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (out_attrs.flags.Contains(Attributes::Flags::kHasBuiltinSampleMask)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a
// scalar. Use the first element only.
auto* sample_mask_array_type =
store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(sample_mask_array_type);
ok = EmitPipelineOutput(var_name, store_type, {0}, sample_mask_array_type->type,
forced_member_type, out_attrs, return_members,
return_exprs);
} else {
// The normal path.
ok =
EmitPipelineOutput(var_name, store_type, {}, store_type, forced_member_type,
out_attrs, return_members, return_exprs);
}
if (!ok) {
return false;
}
}
}
if (return_members.IsEmpty()) {
// This can occur if only the PointSize member is accessed, because we
// never emit it.
return_type = ty_.Void()->Build(builder_);
} else {
// Create and register the result type.
auto* str = create<ast::Struct>(Source{}, builder_.Ident(return_struct_sym),
return_members, tint::Empty);
parser_impl_.AddTypeDecl(return_struct_sym, str);
return_type = builder_.ty.Of(str);
// Add the return-value statement.
stmts.Push(builder_.Return(
source, builder_.Call(source, return_type, std::move(return_exprs))));
}
}
// The wrapper always carries the pipeline stage attribute.
tint::Vector<const ast::Attribute*, 2> fn_attrs{
create<ast::StageAttribute>(source, ep_info_->stage),
};
if (ep_info_->stage == ast::PipelineStage::kCompute) {
auto& size = ep_info_->workgroup_size;
// A workgroup size attribute is only attached when all three dimensions
// are non-zero.
// NOTE(review): given that condition, the `size.y ?` and `size.z ?` tests
// below can never select nullptr.
if (size.x != 0 && size.y != 0 && size.z != 0) {
const ast::Expression* x = builder_.Expr(i32(size.x));
const ast::Expression* y = size.y ? builder_.Expr(i32(size.y)) : nullptr;
const ast::Expression* z = size.z ? builder_.Expr(i32(size.z)) : nullptr;
fn_attrs.Push(create<ast::WorkgroupAttribute>(Source{}, x, y, z));
}
}
builder_.Func(source, ep_info_->name, std::move(decl.params), return_type, std::move(stmts),
std::move(fn_attrs));
return true;
}
// Fills |decl| with the name, parameters and return type of the SPIR-V
// function being emitted. Returns false if an error has been logged, in which
// case |decl| is left untouched.
bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
    if (failed()) {
        return false;
    }

    const std::string fn_name = namer_.Name(function_.result_id());

    // Surprisingly, the "type id" on an OpFunction is the result type of the
    // function, not the type of the function.  This is the one exceptional case
    // in SPIR-V where the type ID is not the type of the result ID.
    auto* result_type = parser_impl_.ConvertType(function_.type_id());
    if (failed()) {
        return false;
    }
    if (!result_type) {
        return Fail() << "internal error: unregistered return type for function with ID "
                      << function_.result_id();
    }

    ParameterList param_list;
    function_.ForEachParam([this, &param_list](const spvtools::opt::Instruction* param) {
        // Valid SPIR-V requires function call parameters to be non-null
        // instructions.
        TINT_ASSERT(param != nullptr);

        // Handle objects take their type from the handle, everything else
        // from the ordinary type conversion.
        const Type* param_type = nullptr;
        if (IsHandleObj(*param)) {
            param_type = parser_impl_.GetHandleTypeForSpirvHandle(*param);
        } else {
            param_type = parser_impl_.ConvertType(param->type_id());
        }

        if (param_type == nullptr) {
            // An error has already been logged and a diagnostic emitted.
            // Nothing more to do for this parameter.
            return;
        }

        // Parameters are treated as const declarations.
        param_list.Push(
            parser_impl_.MakeParameter(param->result_id(), param_type, Attributes{}));
        // The value is accessible by name.
        identifier_types_.emplace(param->result_id(), param_type);
    });
    if (failed()) {
        return false;
    }

    decl->name = fn_name;
    decl->params = std::move(param_list);
    decl->return_type = result_type;
    decl->attributes = {};

    return success();
}
// Returns true if the type of |obj| is an image, a sampler, or a pointer in
// the UniformConstant storage class.
bool FunctionEmitter::IsHandleObj(const spvtools::opt::Instruction& obj) {
    TINT_ASSERT(obj.type_id() != 0u);
    auto* obj_type = type_mgr_->GetType(obj.type_id());
    TINT_ASSERT(obj_type);

    if (obj_type->AsImage() || obj_type->AsSampler()) {
        return true;
    }
    if (!obj_type->AsPointer()) {
        return false;
    }
    const auto storage_class =
        static_cast<spv::StorageClass>(obj_type->AsPointer()->storage_class());
    return storage_class == spv::StorageClass::UniformConstant;
}
// Pointer overload: a null |obj| is never a handle object; otherwise defers to
// the reference overload.
bool FunctionEmitter::IsHandleObj(const spvtools::opt::Instruction* obj) {
    if (obj == nullptr) {
        return false;
    }
    return IsHandleObj(*obj);
}
// Returns the converted store type of the variable declared by
// |var_decl_inst|, found by reading the instruction that defines the
// variable's own type.
const Type* FunctionEmitter::GetVariableStoreType(const spvtools::opt::Instruction& var_decl_inst) {
    // Normally we use the SPIRV-Tools optimizer to manage types.
    // But when two struct types have the same member types and decorations,
    // but differ only in member names, the two struct types will be
    // represented by a single common internal struct type.
    // So avoid the optimizer's representation and instead follow the
    // SPIR-V instructions themselves.
    const auto* pointer_type_inst = def_use_mgr_->GetDef(var_decl_inst.type_id());

    // In-operand 1 of the variable's type instruction is the store type ID.
    const auto pointee_type_id = pointer_type_inst->GetSingleWordInOperand(1);
    return parser_impl_.ConvertType(pointee_type_id);
}
// Runs the body emission pipeline in order: control flow analysis and
// validation, value registration, then emission of variables and statements.
// Stops at the first step that reports failure and returns false.
bool FunctionEmitter::EmitBody() {
    RegisterBasicBlocks();

    if (!(TerminatorsAreValid() && RegisterMerges())) {
        return false;
    }

    ComputeBlockOrderAndPositions();

    // Each step runs only if every earlier one succeeded (short-circuit
    // evaluation preserves the ordering).
    const bool analyzed = VerifyHeaderContinueMergeOrder() &&  //
                          LabelControlFlowConstructs() &&      //
                          FindSwitchCaseHeaders() &&           //
                          ClassifyCFGEdges() &&                //
                          FindIfSelectionInternalHeaders() &&  //
                          RegisterSpecialBuiltInVariables() && //
                          RegisterLocallyDefinedValues();
    if (!analyzed) {
        return false;
    }

    FindValuesNeedingNamedOrHoistedDefinition();

    if (!(EmitFunctionVariables() && EmitFunctionBodyStatements())) {
        return false;
    }

    return success();
}
// Creates a BlockInfo record for every basic block in the function, keyed by
// the block's ID.
void FunctionEmitter::RegisterBasicBlocks() {
    for (auto& basic_block : function_) {
        auto info = std::make_unique<BlockInfo>(basic_block);
        block_info_[basic_block.id()] = std::move(info);
    }
}
bool FunctionEmitter::TerminatorsAreValid() {
if (failed()) {
return false;
}
const auto entry_id = function_.begin()->id();
for (const auto& block : function_) {
if (!block.terminator()) {
return Fail() << "Block " << block.id() << " has no terminator";
}
}
for (const auto& block : function_) {
block.WhileEachSuccessorLabel([this, &block, entry_id](const uint32_t succ_id) -> bool {
if (succ_id == entry_id) {
return Fail() << "Block " << block.id() << " branches to function entry block "
<< entry_id;
}
if (!GetBlockInfo(succ_id)) {
return Fail() << "Block " << block.id() << " in function "
<< function_.DefInst().result_id() << " branches to " << succ_id
<< " which is not a block in the function";
}
return true;
});
}
return success();
}
// Records the structured control flow declarations of each block in its
// BlockInfo record: for a block with a merge instruction, links the header to
// its merge block and, for a loop merge, to its continue target. Also marks
// blocks that are their own continue target. Returns false after logging an
// error when a declaration is invalid.
bool FunctionEmitter::RegisterMerges() {
if (failed()) {
return false;
}
const auto entry_id = function_.begin()->id();
for (const auto& block : function_) {
const auto block_id = block.id();
auto* block_info = GetBlockInfo(block_id);
if (!block_info) {
return Fail() << "internal error: block " << block_id
<< " missing; blocks should already "
"have been registered";
}
if (const auto* inst = block.GetMergeInst()) {
// The terminator must be compatible with the kind of merge instruction.
auto terminator_opcode = opcode(block.terminator());
switch (opcode(inst)) {
case spv::Op::OpSelectionMerge:
if ((terminator_opcode != spv::Op::OpBranchConditional) &&
(terminator_opcode != spv::Op::OpSwitch)) {
return Fail() << "Selection header " << block_id
<< " does not end in an OpBranchConditional or "
"OpSwitch instruction";
}
break;
case spv::Op::OpLoopMerge:
if ((terminator_opcode != spv::Op::OpBranchConditional) &&
(terminator_opcode != spv::Op::OpBranch)) {
return Fail() << "Loop header " << block_id
<< " does not end in an OpBranch or "
"OpBranchConditional instruction";
}
break;
default:
break;
}
const uint32_t header = block.id();
auto* header_info = block_info;
// In-operand 0 of the merge instruction is the merge block ID.
const uint32_t merge = inst->GetSingleWordInOperand(0);
auto* merge_info = GetBlockInfo(merge);
if (!merge_info) {
return Fail() << "Structured header block " << header
<< " declares invalid merge block " << merge;
}
if (merge == header) {
return Fail() << "Structured header block " << header
<< " cannot be its own merge block";
}
// A block can be the merge block of at most one header.
if (merge_info->header_for_merge) {
return Fail() << "Block " << merge
<< " declared as merge block for more than one header: "
<< merge_info->header_for_merge << ", " << header;
}
merge_info->header_for_merge = header;
header_info->merge_for_header = merge;
if (opcode(inst) == spv::Op::OpLoopMerge) {
if (header == entry_id) {
return Fail() << "Function entry block " << entry_id
<< " cannot be a loop header";
}
// In-operand 1 of the loop merge instruction is the continue target ID.
const uint32_t ct = inst->GetSingleWordInOperand(1);
auto* ct_info = GetBlockInfo(ct);
if (!ct_info) {
return Fail() << "Structured header " << header
<< " declares invalid continue target " << ct;
}
if (ct == merge) {
return Fail() << "Invalid structured header block " << header
<< ": declares block " << ct
<< " as both its merge block and continue target";
}
// A block can be the continue target of at most one header.
if (ct_info->header_for_continue) {
return Fail() << "Block " << ct
<< " declared as continue target for more than one header: "
<< ct_info->header_for_continue << ", " << header;
}
ct_info->header_for_continue = header;
header_info->continue_for_header = ct;
}
}
// Check single-block loop cases.
bool is_single_block_loop = false;
block_info->basic_block->ForEachSuccessorLabel(
[&is_single_block_loop, block_id](const uint32_t succ) {
if (block_id == succ) {
is_single_block_loop = true;
}
});
const auto ct = block_info->continue_for_header;
block_info->is_continue_entire_loop = ct == block_id;
if (is_single_block_loop && !block_info->is_continue_entire_loop) {
return Fail() << "Block " << block_id
<< " branches to itself but is not its own continue target";
}
// It's valid for the header of a multi-block loop to declare
// itself as its own continue target.
}
return success();
}
// Computes the block order from a reverse structured post-order traversal of
// the function, and records each block's index in that order as its position.
void FunctionEmitter::ComputeBlockOrderAndPositions() {
    block_order_ = StructuredTraverser(function_).ReverseStructuredPostOrder();

    uint32_t next_pos = 0;
    for (const auto ordered_id : block_order_) {
        GetBlockInfo(ordered_id)->pos = next_pos;
        ++next_pos;
    }

    // The invalid block position is not the position of any block that is in the
    // order.
    assert(block_order_.size() <= kInvalidBlockPos);
}
// Checks that every header block in the block order is correctly positioned
// relative to its merge block and, for loop headers, its continue target.
// Returns false as soon as a violation has been logged.
bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
    // Verify interval rules for a structured header block:
    //
    //    If the CFG satisfies structured control flow rules, then:
    //    If header H is reachable, then the following "interval rules" hold,
    //    where M(H) is H's merge block, and CT(H) is H's continue target:
    //
    //      Pos(H) < Pos(M(H))
    //
    //      If CT(H) exists, then:
    //         Pos(H) <= Pos(CT(H))
    //         Pos(CT(H)) < Pos(M)
    //
    for (auto block_id : block_order_) {
        const auto* block_info = GetBlockInfo(block_id);
        const auto merge = block_info->merge_for_header;
        if (merge == 0) {
            // Not a header; nothing to check.
            continue;
        }
        // This is a header.
        const auto header = block_id;
        const auto* header_info = block_info;
        const auto header_pos = header_info->pos;
        const auto merge_pos = GetBlockInfo(merge)->pos;

        // Pos(H) < Pos(M(H))
        // Note: When recording merges we made sure H != M(H)
        if (merge_pos <= header_pos) {
            // TODO(dneto): Report a path from the entry block to the merge block
            // without going through the header block.
            return Fail() << "Header " << header << " does not strictly dominate its merge block "
                          << merge;
        }

        const auto ct = block_info->continue_for_header;
        if (ct == 0) {
            // A selection header; there is no continue target to check.
            continue;
        }
        // Furthermore, this is a loop header.
        const auto* ct_info = GetBlockInfo(ct);
        const auto ct_pos = ct_info->pos;
        // Pos(H) <= Pos(CT(H))
        if (ct_pos < header_pos) {
            // Stop at the first violation, as the other checks in this
            // function do, rather than continuing to validate after an error
            // has been logged.
            return Fail() << "Loop header " << header << " does not dominate its continue target "
                          << ct;
        }
        // Pos(CT(H)) < Pos(M(H))
        // Note: When recording merges we made sure CT(H) != M(H)
        if (merge_pos <= ct_pos) {
            return Fail() << "Merge block " << merge << " for loop headed at block " << header
                          << " appears at or before the loop's continue "
                             "construct headed by "
                             "block "
                          << ct;
        }
    }
    return success();
}
// Builds the list of structured control flow constructs for the function and
// assigns each block in the block order to a construct. Returns false after
// logging an error if the constructs do not nest properly.
bool FunctionEmitter::LabelControlFlowConstructs() {
// Label each block in the block order with its nearest enclosing structured
// control flow construct. Populates the |construct| member of BlockInfo.
// Keep a stack of enclosing structured control flow constructs. Start
// with the synthetic construct representing the entire function.
//
// Scan from left to right in the block order, and check conditions
// on each block in the following order:
//
// a. When you reach a merge block, the top of the stack should
// be the associated header. Pop it off.
// b. When you reach a header, push it on the stack.
// c. When you reach a continue target, push it on the stack.
// (A block can be both a header and a continue target.)
// d. When you reach a block with an edge branching backward (in the
// structured order) to block T:
// T should be a loop header, and the top of the stack should be a
// continue target associated with T.
// This is the end of the continue construct. Pop the continue
// target off the stack.
//
// Note: A loop header can declare itself as its own continue target.
//
// Note: For a single-block loop, that block is a header, its own
// continue target, and its own backedge block.
//
// Note: We pop the merge off first because a merge block that marks
// the end of one construct can be a single-block loop. So that block
// is a merge, a header, a continue target, and a backedge block.
// But we want to finish processing of the merge before dealing with
// the loop.
//
// In the same scan, mark each basic block with the nearest enclosing
// header: the most recent header for which we haven't reached its merge
// block. Also mark the most recent continue target for which we
// haven't reached the backedge block.
TINT_ASSERT(block_order_.size() > 0);
constructs_.Clear();
const auto entry_id = block_order_[0];
// The stack of enclosing constructs.
tint::Vector<Construct*, 4> enclosing;
// Creates a control flow construct and pushes it onto the stack.
// Its parent is the top of the stack, or nullptr if the stack is empty.
// Returns the newly created construct.
auto push_construct = [this, &enclosing](size_t depth, Construct::Kind k, uint32_t begin_id,
uint32_t end_id) -> Construct* {
const auto begin_pos = GetBlockInfo(begin_id)->pos;
// An end ID of 0 means the construct extends to the end of the block order.
const auto end_pos =
end_id == 0 ? uint32_t(block_order_.size()) : GetBlockInfo(end_id)->pos;
const auto* parent = enclosing.IsEmpty() ? nullptr : enclosing.Back();
auto scope_end_pos = end_pos;
// A loop construct is added right after its associated continue construct.
// In that case, adjust the parent up.
if (k == Construct::kLoop) {
TINT_ASSERT(parent);
TINT_ASSERT(parent->kind == Construct::kContinue);
scope_end_pos = parent->end_pos;
parent = parent->parent;
}
constructs_.Push(std::make_unique<Construct>(parent, static_cast<int>(depth), k, begin_id,
end_id, begin_pos, end_pos, scope_end_pos));
Construct* result = constructs_.Back().get();
enclosing.Push(result);
return result;
};
// Make a synthetic kFunction construct to enclose all blocks in the function.
push_construct(0, Construct::kFunction, entry_id, 0);
// The entry block can be a selection construct, so be sure to process
// it anyway.
for (uint32_t i = 0; i < block_order_.size(); ++i) {
const auto block_id = block_order_[i];
TINT_ASSERT(block_id > 0);
auto* block_info = GetBlockInfo(block_id);
TINT_ASSERT(block_info);
if (enclosing.IsEmpty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
const Construct* top = enclosing.Back();
// A single block can end several nested constructs, so keep popping.
while (block_id == top->end_id) {
// We've reached a predeclared end of the construct. Pop it off the
// stack.
enclosing.Pop();
if (enclosing.IsEmpty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
top = enclosing.Back();
}
const auto merge = block_info->merge_for_header;
if (merge != 0) {
// The current block is a header.
const auto header = block_id;
const auto* header_info = block_info;
const auto depth = static_cast<size_t>(1 + top->depth);
const auto ct = header_info->continue_for_header;
if (ct != 0) {
// The current block is a loop header.
// We should see the continue construct after the loop construct, so
// push the loop construct last.
// From the interval rule, the continue construct consists of blocks
// in the block order, starting at the continue target, until just
// before the merge block.
top = push_construct(depth, Construct::kContinue, ct, merge);
// A loop header that is its own continue target will have an
// empty loop construct. Only create a loop construct when
// the continue target is *not* the same as the loop header.
if (header != ct) {
// From the interval rule, the loop construct consists of blocks
// in the block order, starting at the header, until just
// before the continue target.
top = push_construct(depth, Construct::kLoop, header, ct);
// If the loop header branches to two different blocks inside the loop
// construct, then the loop body should be modeled as an if-selection
// construct
tint::Vector<uint32_t, 4> targets;
header_info->basic_block->ForEachSuccessorLabel(
[&targets](const uint32_t target) { targets.Push(target); });
if ((targets.Length() == 2u) && targets[0] != targets[1]) {
const auto target0_pos = GetBlockInfo(targets[0])->pos;
const auto target1_pos = GetBlockInfo(targets[1])->pos;
if (top->ContainsPos(target0_pos) && top->ContainsPos(target1_pos)) {
// Insert a synthetic if-selection
top = push_construct(depth + 1, Construct::kIfSelection, header, ct);
}
}
}
} else {
// From the interval rule, the selection construct consists of blocks
// in the block order, starting at the header, until just before the
// merge block.
const auto branch_opcode = opcode(header_info->basic_block->terminator());
const auto kind = (branch_opcode == spv::Op::OpBranchConditional)
? Construct::kIfSelection
: Construct::kSwitchSelection;
top = push_construct(depth, kind, header, merge);
}
}
TINT_ASSERT(top);
// The block belongs to whichever construct is now on top of the stack.
block_info->construct = top;
}
// At the end of the block list, we should only have the kFunction construct
// left.
if (enclosing.Length() != 1) {
return Fail() << "internal error: unbalanced structured constructs when "
"labeling structured constructs: ended with "
<< enclosing.Length() - 1 << " unterminated constructs";
}
const auto* top = enclosing[0];
if (top->kind != Construct::kFunction || top->depth != 0) {
return Fail() << "internal error: outermost construct is not a function?!";
}
return success();
}
bool FunctionEmitter::FindSwitchCaseHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kSwitchSelection) {
continue;
}
const auto* branch = GetBlockInfo(construct->begin_id)->basic_block->terminator();
// Mark the default block
const auto default_id = branch->GetSingleWordInOperand(1);
auto* default_block = GetBlockInfo(default_id);
// A default target can't be a backedge.
if (construct->begin_pos >= default_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
// and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default target block " << default_id << " can't be a back-edge";
}
// A default target can be the merge block, but can't go past it.
if (construct->end_pos < default_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default block " << default_id
<< " escapes the selection construct";
}
if (default_block->default_head_for) {
// An OpSwitch must dominate its cases, including the default target.
return Fail() << "Block " << default_id
<< " is declared as the default target for two OpSwitch "
"instructions, at blocks "
<< default_block->default_head_for->begin_id << " and "
<< construct->begin_id;
}
if ((default_block->header_for_merge != 0) &&
(default_block->header_for_merge != construct->begin_id)) {
// The switch instruction for this default block is an alternate path to
// the merge block, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << default_block->id
<< " is the default block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< default_block->header_for_merge << " (violates dominance rule)";
}
default_block->default_head_for = construct.get();
default_block->default_is_merge = default_block->pos == construct->end_pos;
// Map a case target to the list of values selecting that case.
std::unordered_map<uint32_t, tint::Vector<uint64_t, 4>> block_to_values;
tint::Vector<uint32_t, 4> case_targets;
std::unordered_set<uint64_t> case_values;
// Process case targets.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
if (case_values.count(value)) {
return Fail() << "Duplicate case value " << value << " in OpSwitch in block "
<< construct->begin_id;
}
case_values.insert(value);
if (block_to_values.count(case_target_id) == 0) {
case_targets.Push(case_target_id);
}
block_to_values[case_target_id].Push(value);
}
for (uint32_t case_target_id : case_targets) {
auto* case_block = GetBlockInfo(case_target_id);
case_block->case_values = std::move(block_to_values[case_target_id]);
// A case target can't be a back-edge.
if (construct->begin_pos >= case_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
// and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " can't be a back-edge";
}
// A case target can be the merge block, but can't go past it.
if (construct->end_pos < case_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " escapes the selection construct";
}
if (case_block->header_for_merge != 0 &&
case_block->header_for_merge != construct->begin_id) {
// The switch instruction for this case block is an alternate path to
// the merge block, and hence the merge block is not dominated by its
// own (different) header.
return Fail() << "Block " << case_block->id
<< " is a case block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< case_block->header_for_merge << " (violates dominance rule)";
}
// Mark the target as a case target.
if (case_block->case_head_for) {
// An OpSwitch must dominate its cases.
return Fail() << "Block " << case_target_id
<< " is declared as the switch case target for two OpSwitch "
"instructions, at blocks "
<< case_block->case_head_for->begin_id << " and "
<< construct->begin_id;
}
case_block->case_head_for = construct.get();
}
}
return success();
}
// Returns the block info for the header block associated with construct `c`
// when `c` is a loop, a switch-selection, or a continue construct. Returns
// nullptr when `c` is null or is any other kind of construct.
BlockInfo* FunctionEmitter::HeaderIfBreakable(const Construct* c) {
    if (!c) {
        return nullptr;
    }
    const auto kind = c->kind;
    if (kind == Construct::kLoop || kind == Construct::kSwitchSelection) {
        // The construct begins at its own header block.
        return GetBlockInfo(c->begin_id);
    }
    if (kind == Construct::kContinue) {
        // A continue construct begins at the continue target, which records
        // the ID of the header it belongs to.
        const uint32_t header_id = GetBlockInfo(c->begin_id)->header_for_continue;
        return GetBlockInfo(header_id);
    }
    return nullptr;
}
// Given a continue construct `c`, returns the loop construct associated with
// the same header, found by walking up from the header block's construct.
// Returns nullptr when `c` is null, when `c` is not a continue construct, when
// the continue target is its own header, or when no enclosing loop construct
// is found.
const Construct* FunctionEmitter::SiblingLoopConstruct(const Construct* c) const {
    const bool is_continue = (c != nullptr) && (c->kind == Construct::kContinue);
    if (!is_continue) {
        return nullptr;
    }
    const uint32_t ct_id = c->begin_id;
    const uint32_t loop_header_id = GetBlockInfo(ct_id)->header_for_continue;
    if (loop_header_id == ct_id) {
        // The continue construct covers the whole loop, so there is no
        // separate loop construct.
        return nullptr;
    }
    // Start at the construct of the header block and climb through parents
    // until a loop construct is reached. In future we might handle the corner
    // case where the same block is both a loop header and a selection header,
    // e.g. where the loop header block has a conditional branch going to
    // distinct targets inside the loop body.
    for (const Construct* walker = GetBlockInfo(loop_header_id)->construct; walker != nullptr;
         walker = walker->parent) {
        if (walker->kind == Construct::kLoop) {
            return walker;
        }
    }
    return nullptr;
}
bool FunctionEmitter::ClassifyCFGEdges() {
if (failed()) {
return false;
}
// Checks validity of CFG edges leaving each basic block. This implicitly
// checks dominance rules for headers and continue constructs.
//
// For each branch encountered, classify each edge (S,T) as:
// - a back-edge
// - a structured exit (specific ways of branching to enclosing construct)
// - a normal (forward) edge, either natural control flow or a case fallthrough
//
// If more than one block is targeted by a normal edge, then S must be a
// structured header.
//
// Term: NEC(B) is the nearest enclosing construct for B.
//
// If edge (S,T) is a normal edge, and NEC(S) != NEC(T), then
// T is the header block of its NEC(T), and
// NEC(S) is the parent of NEC(T).
for (const auto src : block_order_) {
TINT_ASSERT(src > 0);
auto* src_info = GetBlockInfo(src);
TINT_ASSERT(src_info);
const auto src_pos = src_info->pos;
const auto& src_construct = *(src_info->construct);
// Compute the ordered list of unique successors.
tint::Vector<uint32_t, 4> successors;
{
std::unordered_set<uint32_t> visited;
src_info->basic_block->ForEachSuccessorLabel(
[&successors, &visited](const uint32_t succ) {
if (visited.count(succ) == 0) {
successors.Push(succ);
visited.insert(succ);
}
});
}
// There should only be one backedge per backedge block.
uint32_t num_backedges = 0;
// Track destinations for normal forward edges, either kForward or kCaseFallThrough.
// These count toward the need to have a merge instruction. We also track kIfBreak edges
// because when used with normal forward edges, we'll need to generate a flow guard
// variable.
tint::Vector<uint32_t, 4> normal_forward_edges;
tint::Vector<uint32_t, 4> if_break_edges;
if (successors.IsEmpty() && src_construct.enclosing_continue) {
// Kill and return are not allowed in a continue construct.
return Fail() << "Invalid function exit at block " << src
<< " from continue construct starting at "
<< src_construct.enclosing_continue->begin_id;
}
for (const auto dest : successors) {
const auto* dest_info = GetBlockInfo(dest);
// We've already checked terminators are valid.
TINT_ASSERT(dest_info);
const auto dest_pos = dest_info->pos;
// Insert the edge kind entry and keep a handle to update
// its classification.
EdgeKind& edge_kind = src_info->succ_edge[dest];
if (src_pos >= dest_pos) {
// This is a backedge.
edge_kind = EdgeKind::kBack;
num_backedges++;
const auto* continue_construct = src_construct.enclosing_continue;
if (!continue_construct) {
return Fail() << "Invalid backedge (" << src << "->" << dest << "): " << src
<< " is not in a continue construct";
}
if (src_pos != continue_construct->end_pos - 1) {
return Fail() << "Invalid exit (" << src << "->" << dest
<< ") from continue construct: " << src
<< " is not the last block in the continue construct "
"starting at "
<< src_construct.begin_id << " (violates post-dominance rule)";
}
const auto* ct_info = GetBlockInfo(continue_construct->begin_id);
TINT_ASSERT(ct_info);
if (ct_info->header_for_continue != dest) {
return Fail() << "Invalid backedge (" << src << "->" << dest
<< "): does not branch to the corresponding loop header, "
"expected "
<< ct_info->header_for_continue;
}
} else {
// This is a forward edge.
// For now, classify it that way, but we might update it.
edge_kind = EdgeKind::kForward;
// Exit from a continue construct can only be from the last block.
const auto* continue_construct = src_construct.enclosing_continue;
if (continue_construct != nullptr) {
if (continue_construct->ContainsPos(src_pos) &&
!continue_construct->ContainsPos(dest_pos) &&
(src_pos != continue_construct->end_pos - 1)) {
return Fail()
<< "Invalid exit (" << src << "->" << dest
<< ") from continue construct: " << src
<< " is not the last block in the continue construct "
"starting at "
<< continue_construct->begin_id << " (violates post-dominance rule)";
}
}
// Check valid structured exit cases.
if (edge_kind == EdgeKind::kForward) {
// Check for a 'break' from a loop or from a switch.
const auto* breakable_header =
HeaderIfBreakable(src_construct.enclosing_loop_or_continue_or_switch);
if (breakable_header != nullptr) {
if (dest == breakable_header->merge_for_header) {
// It's a break.
edge_kind =
(breakable_header->construct->kind == Construct::kSwitchSelection)
? EdgeKind::kSwitchBreak
: EdgeKind::kLoopBreak;
}
}
}
if (edge_kind == EdgeKind::kForward) {
// Check for a 'continue' from within a loop.
const auto* loop_header = HeaderIfBreakable(src_construct.enclosing_loop);
if (loop_header != nullptr) {
if (dest == loop_header->continue_for_header) {
// It's a continue.
edge_kind = EdgeKind::kLoopContinue;
}
}
}
if (edge_kind == EdgeKind::kForward) {
const auto& header_info = *GetBlockInfo(src_construct.begin_id);
if (dest == header_info.merge_for_header) {
// Branch to construct's merge block. The loop break and
// switch break cases have already been covered.
edge_kind = EdgeKind::kIfBreak;
}
}
// A forward edge into a case construct that comes from something
// other than the OpSwitch is actually a fallthrough.
if (edge_kind == EdgeKind::kForward) {
const auto* switch_construct =
(dest_info->case_head_for ? dest_info->case_head_for
: dest_info->default_head_for);
if (switch_construct != nullptr) {
if (src != switch_construct->begin_id) {
edge_kind = EdgeKind::kCaseFallThrough;
}
}
}
// The edge-kind has been finalized.
if ((edge_kind == EdgeKind::kForward) ||
(edge_kind == EdgeKind::kCaseFallThrough)) {
normal_forward_edges.Push(dest);
}
if (edge_kind == EdgeKind::kIfBreak) {
if_break_edges.Push(dest);
}
if ((edge_kind == EdgeKind::kForward) ||
(edge_kind == EdgeKind::kCaseFallThrough)) {
// Check for an invalid forward exit out of this construct.
if (dest_info->pos > src_construct.end_pos) {
// In most cases we're bypassing the merge block for the source
// construct.
auto end_block = src_construct.end_id;
const char* end_block_desc = "merge block";
if (src_construct.kind == Construct::kLoop) {
// For a loop construct, we have two valid places to go: the
// continue target or the merge for the loop header, which is
// further down.
const auto loop_merge =
GetBlockInfo(src_construct.begin_id)->merge_for_header;
if (dest_info->pos >= GetBlockInfo(loop_merge)->pos) {
// We're bypassing the loop's merge block.
end_block = loop_merge;
} else {
// We're bypassing the loop's continue target, and going into
// the middle of the continue construct.
end_block_desc = "continue target";
}
}
return Fail() << "Branch from block " << src << " to block " << dest
<< " is an invalid exit from construct starting at block "
<< src_construct.begin_id << "; branch bypasses "
<< end_block_desc << " " << end_block;
}
// Check dominance.
// Look for edges that violate the dominance condition: a branch
// from X to Y where:
// If Y is in a nearest enclosing continue construct headed by
// CT:
// Y is not CT, and
// In the structured order, X appears before CT order or
// after CT's backedge block.
// Otherwise, if Y is in a nearest enclosing construct
// headed by H:
// Y is not H, and
// In the structured order, X appears before H or after H's
// merge block.
const auto& dest_construct = *(dest_info->construct);
if (dest != dest_construct.begin_id && !dest_construct.ContainsPos(src_pos)) {
return Fail()
<< "Branch from " << src << " to " << dest << " bypasses "
<< (dest_construct.kind == Construct::kContinue ? "continue target "
: "header ")
<< dest_construct.begin_id << " (dominance rule violated)";
}
}
// Error on the fallthrough at the end in order to allow the better error messages
// from the above checks to happen.
if (edge_kind == EdgeKind::kCaseFallThrough) {
return Fail() << "Fallthrough not permitted in WGSL";
}
} // end forward edge
} // end successor
if (num_backedges > 1) {
return Fail() << "Block " << src << " has too many backedges: " << num_backedges;
}
if ((normal_forward_edges.Length() > 1) && (src_info->merge_for_header == 0)) {
return Fail() << "Control flow diverges at block " << src << " (to "
<< normal_forward_edges[0] << ", " << normal_forward_edges[1]
<< ") but it is not a structured header (it has no merge "
"instruction)";
}
if ((normal_forward_edges.Length() + if_break_edges.Length() > 1) &&
(src_info->merge_for_header == 0)) {
// There is a branch to the merge of an if-selection combined
// with an other normal forward branch. Control within the
// if-selection needs to be gated by a flow predicate.
for (auto if_break_dest : if_break_edges) {
auto* head_info = GetBlockInfo(GetBlockInfo(if_break_dest)->header_for_merge);
// Generate a guard name, but only once.
if (head_info->flow_guard_name.empty()) {
const std::string guard = "guard" + std::to_string(head_info->id);
head_info->flow_guard_name = namer_.MakeDerivedName(guard);
}
}
}
}
return success();
}
bool FunctionEmitter::FindIfSelectionInternalHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kIfSelection) {
continue;
}
auto* if_header_info = GetBlockInfo(construct->begin_id);
const auto* branch = if_header_info->basic_block->terminator();
const auto true_head = branch->GetSingleWordInOperand(1);
const auto false_head = branch->GetSingleWordInOperand(2);
auto* true_head_info = GetBlockInfo(true_head);
auto* false_head_info = GetBlockInfo(false_head);
const auto true_head_pos = true_head_info->pos;
const auto false_head_pos = false_head_info->pos;
const bool contains_true = construct->ContainsPos(true_head_pos);
const bool contains_false = construct->ContainsPos(false_head_pos);
// The cases for each edge are:
// - kBack: invalid because it's an invalid exit from the selection
// - kSwitchBreak ; record this for later special processing
// - kLoopBreak ; record this for later special processing
// - kLoopContinue ; record this for later special processing
// - kIfBreak; normal case, may require a guard variable.
// - kFallThrough; invalid exit from the selection
// - kForward; normal case
if_header_info->true_kind = if_header_info->succ_edge[true_head];
if_header_info->false_kind = if_header_info->succ_edge[false_head];
if (contains_true) {
if_header_info->true_head = true_head;
}
if (contains_false) {
if_header_info->false_head = false_head;
}
if (contains_true && (true_head_info->header_for_merge != 0) &&
(true_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the true head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << true_head << " is the true branch for if-selection header "
<< construct->begin_id << " and also the merge block for header block "
<< true_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_false && (false_head_info->header_for_merge != 0) &&
(false_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the false head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << false_head
<< " is the false branch for if-selection header " << construct->begin_id
<< " and also the merge block for header block "
<< false_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_true && contains_false && (true_head_pos != false_head_pos)) {
// This construct has both a "then" clause and an "else" clause.
//
// We have this structure:
//
// Option 1:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause)
// ...
// * end-false-clause
// * premerge-head
// ...
// * selection merge
//
// Option 2:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause) and also premerge-head
// ...
// * end-false-clause
// * selection merge
//
// Option 3:
//
// * condbranch
// * false-head (start of else-clause)
// ...
// * end-else-clause
// * true-head (start of then-clause) and also premerge-head
// ...
// * end-then-clause
// * selection merge
//
// The premerge-head exists if there is a kForward branch from the end
// of the first clause to a block within the surrounding selection.
// The first clause might be a then-clause or an else-clause.
const auto second_head = std::max(true_head_pos, false_head_pos);
const auto end_first_clause_pos = second_head - 1;
TINT_ASSERT(end_first_clause_pos < block_order_.size());
const auto end_first_clause = block_order_[end_first_clause_pos];
uint32_t premerge_id = 0;
uint32_t if_break_id = 0;
for (auto& then_succ_iter : GetBlockInfo(end_first_clause)->succ_edge) {
const uint32_t dest_id = then_succ_iter.first;
const auto edge_kind = then_succ_iter.second;
switch (edge_kind) {
case EdgeKind::kIfBreak:
if_break_id = dest_id;
break;
case EdgeKind::kForward: {
if (construct->ContainsPos(GetBlockInfo(dest_id)->pos)) {
// It's a premerge.
if (premerge_id != 0) {
// TODO(dneto): I think this is impossible to trigger at this
// point in the flow. It would require a merge instruction to
// get past the check of "at-most-one-forward-edge".
return Fail()
<< "invalid structure: then-clause headed by block "
<< true_head << " ending at block " << end_first_clause
<< " has two forward edges to within selection"
<< " going to " << premerge_id << " and " << dest_id;
}
premerge_id = dest_id;
auto* dest_block_info = GetBlockInfo(dest_id);
if_header_info->premerge_head = dest_id;
if (dest_block_info->header_for_merge != 0) {
// Premerge has two edges coming into it, from the then-clause
// and the else-clause. It's also, by construction, not the
// merge block of the if-selection. So it must not be a merge
// block itself. The OpBranchConditional instruction for the
// false head block is an alternate path to the merge block, and
// hence the merge block is not dominated by its own (different)
// header.
return Fail()
<< "Block " << premerge_id << " is the merge block for "
<< dest_block_info->header_for_merge
<< " but has alternate paths reaching it, starting from"
<< " blocks " << true_head << " and " << false_head
<< " which are the true and false branches for the"
<< " if-selection header block " << construct->begin_id
<< " (violates dominance rule)";
}
}
break;
}
default:
break;
}
}
if (if_break_id != 0 && premerge_id != 0) {
return Fail() << "Block " << end_first_clause << " in if-selection headed at block "
<< construct->begin_id << " branches to both the merge block "
<< if_break_id << " and also to block " << premerge_id
<< " later in the selection";
}
}
}
return success();
}
bool FunctionEmitter::EmitFunctionVariables() {
if (failed()) {
return false;
}
for (auto& inst : *function_.entry()) {
if (opcode(inst) != spv::Op::OpVariable) {
continue;
}
auto* var_store_type = GetVariableStoreType(inst);
if (failed()) {
return false;
}
const ast::Expression* initializer = nullptr;
if (inst.NumInOperands() > 1) {
// SPIR-V initializers are always constants.
// (OpenCL also allows the ID of an OpVariable, but we don't handle that
// here.)
initializer = parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)).expr;
if (!initializer) {
return false;
}
}
auto* var = parser_impl_.MakeVar(inst.result_id(), core::AddressSpace::kUndefined,
core::Access::kUndefined, var_store_type, initializer,
Attributes{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt);
auto* var_type = ty_.Reference(core::AddressSpace::kUndefined, var_store_type);
identifier_types_.emplace(inst.result_id(), var_type);
}
return success();
}
// Returns AddressOf(expr) when `expr` has a reference type and the result type
// of `inst` is a SPIR-V pointer type. In every other case, including a null
// `inst`, an empty `expr`, or an unknown SPIR-V type, returns `expr` unchanged.
TypedExpression FunctionEmitter::AddressOfIfNeeded(TypedExpression expr,
                                                   const spvtools::opt::Instruction* inst) {
    if (!inst || !expr) {
        return expr;
    }
    auto* spirv_type = type_mgr_->GetType(inst->type_id());
    if (!spirv_type) {
        return expr;
    }
    if (expr.type->Is<Reference>() && spirv_type->AsPointer()) {
        return AddressOf(expr);
    }
    return expr;
}
// Builds a typed expression for the value with SPIR-V ID `id`.
//
// The sources are tried in this order: special handling dictated by the ID's
// skip reason, locally named definitions, scalar spec constants, deferred
// singly-used values, declared constants, module-scope OpVariable / OpUndef,
// and finally values defined in a block that is not in the block order.
//
// Returns an empty TypedExpression if a failure was already recorded, or after
// recording a failure for an ID that cannot be handled.
TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
    if (failed()) {
        return {};
    }

    switch (GetSkipReason(id)) {
        case SkipReason::kDontSkip:
            break;
        case SkipReason::kOpaqueObject:
            Fail() << "internal error: unhandled use of opaque object with ID: " << id;
            return {};
        case SkipReason::kSinkPointerIntoUse: {
            // Replace the pointer with its source reference expression.
            auto source_expr = GetDefInfo(id)->sink_pointer_source_expr;
            TINT_ASSERT(source_expr.type->Is<Reference>());
            return source_expr;
        }
        case SkipReason::kPointSizeBuiltinValue: {
            // Substitute the f32 literal 1.0 for the value.
            return {ty_.F32(), create<ast::FloatLiteralExpression>(
                                   Source{}, 1.0, ast::FloatLiteralExpression::Suffix::kF)};
        }
        case SkipReason::kPointSizeBuiltinPointer:
            Fail() << "unhandled use of a pointer to the PointSize builtin, with ID: " << id;
            return {};
        case SkipReason::kSampleMaskInBuiltinPointer:
            Fail() << "unhandled use of a pointer to the SampleMask builtin, with ID: " << id;
            return {};
        case SkipReason::kSampleMaskOutBuiltinPointer: {
            // The result type is always u32.
            auto mask_name = namer_.Name(sample_mask_out_id);
            return TypedExpression{ty_.U32(), builder_.Expr(Source{}, mask_name)};
        }
    }

    // A local named definition: function parameter, let, or var declaration.
    const auto local_it = identifier_types_.find(id);
    if (local_it != identifier_types_.end()) {
        auto local_name = namer_.Name(id);
        auto* local_type = local_it->second;
        return TypedExpression{local_type, builder_.Expr(Source{}, local_name)};
    }

    // A scalar spec constant is referenced by its name.
    if (parser_impl_.IsScalarSpecConstant(id)) {
        auto spec_name = namer_.Name(id);
        return TypedExpression{parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
                               builder_.Expr(Source{}, spec_name)};
    }

    // A stored singly-used value: hand back the stored expression and drop
    // the entry.
    if (singly_used_values_.count(id)) {
        auto stored = std::move(singly_used_values_[id]);
        singly_used_values_.erase(id);
        return stored;
    }

    if (constant_mgr_->FindDeclaredConstant(id) != nullptr) {
        return parser_impl_.MakeConstantExpression(id);
    }

    const auto* inst = def_use_mgr_->GetDef(id);
    if (inst == nullptr) {
        Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
        return {};
    }

    const auto op = opcode(inst);
    if (op == spv::Op::OpVariable) {
        // This occurs for module-scope variables.
        auto var_name = namer_.Name(id);
        // Construct the reference type, mapping storage class correctly.
        const auto* ref_type =
            RemapPointerProperties(parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref), id);
        return TypedExpression{ref_type, builder_.Expr(Source{}, var_name)};
    }
    if (op == spv::Op::OpUndef) {
        // Substitute a null value for undef.
        // This case occurs when OpUndef appears at module scope, as if it were
        // a constant.
        return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
    }

    // A value defined in a block that is not in the block order is replaced
    // by a null value.
    const spvtools::opt::BasicBlock* const defining_bb = ir_context_.get_instr_block(id);
    auto* defining_info = defining_bb ? GetBlockInfo(defining_bb->id()) : nullptr;
    if (defining_info && defining_info->pos == kInvalidBlockPos) {
        return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
    }

    Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
    return {};
}
// Emits the statements for the function body by emitting each basic block in
// block order.
//
// Statements are written to the topmost entry of a stack of statement blocks.
// By this point the interesting control flow boundaries have already been
// recorded in the BlockInfo and Construct objects. Entering a new statement
// grouping pushes onto the stack and schedules that grouping's completion and
// removal at a future block's ID. On entry the stack holds exactly one entry,
// representing the whole function.
//
// Returns false as soon as emitting a block fails. Otherwise returns success().
bool FunctionEmitter::EmitFunctionBodyStatements() {
    TINT_ASSERT(!constructs_.IsEmpty());
    Construct* function_construct = constructs_[0].get();
    TINT_ASSERT(function_construct != nullptr);
    TINT_ASSERT(function_construct->kind == Construct::kFunction);

    // The single stack entry was created before the function construct was
    // computed, so fill in its construct field now.
    // TODO(dneto): refactor how the first construct is created vs.
    // this statements stack entry is populated.
    TINT_ASSERT(statements_stack_.Length() == 1);
    statements_stack_[0].SetConstruct(function_construct);

    for (auto block_id : block_order()) {
        const bool emitted_ok = EmitBasicBlock(*GetBlockInfo(block_id));
        if (!emitted_ok) {
            return false;
        }
    }
    return success();
}
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
// Close off previous constructs.
while (!statements_stack_.IsEmpty() && (statements_stack_.Back().GetEndId() == block_info.id)) {
statements_stack_.Back().Finalize(&builder_);
statements_stack_.Pop();
}
if (statements_stack_.IsEmpty()) {
return Fail() << "internal error: statements stack empty at block " << block_info.id;
}
// Enter new constructs.
tint::Vector<const Construct*, 4> entering_constructs; // inner most comes first
{
auto* here = block_info.construct;
auto* const top_construct = statements_stack_.Back().GetConstruct();
while (here != top_construct) {
// Only enter a construct at its header block.
if (here->begin_id == block_info.id) {
entering_constructs.Push(here);
}
here = here->parent;
}
}
// What constructs can we have entered?
// - It can't be kFunction, because there is only one of those, and it was
// already on the stack at the outermost level.
// - We have at most one of kSwitchSelection, or kLoop because each of those
// is headed by a block with a merge instruction (OpLoopMerge for kLoop,
// and OpSelectionMerge for kSwitchSelection).
// - When there is a kIfSelection, it can't contain another construct,
// because both would have to have their own distinct merge instructions
// and distinct terminators.
// - A kContinue can contain a kContinue
// This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
// that a block can be continue target for at most one header block. See
// test BlockIsContinueForMoreThanOneHeader. If we generalize this,
// then by a dominance argument, the inner loop continue target can only be
// a single-block loop.
// TODO(dneto): Handle this case.
// - If a kLoop is on the outside, its terminator is either:
// - an OpBranch, in which case there is no other construct.
// - an OpBranchConditional, in which case there is either an kIfSelection
// (when both branch targets are different and are inside the loop),
// or no other construct (because the branch targets are the same,
// or one of them is a break or continue).
// - All that's left is a kContinue on the outside, and one of
// kIfSelection, kSwitchSelection, kLoop on the inside.
//
// The kContinue can be the parent of the other. For example, a selection
// starting at the first block of a continue construct.
//
// The kContinue can't be the child of the other because either:
// - The other can't be kLoop because:
// - If the kLoop is for a different loop than the kContinue, then
// the kContinue must be its own loop header, and so the same
// block is two different loops. That's a contradiction.
// - If the kLoop is for the same loop, then this is a contradiction
// because a kContinue and its kLoop have disjoint block sets.
// - The other construct can't be a selection because:
// - The kContinue construct is the entire loop, i.e. the continue
// target is its own loop header block. But then the continue target
// has an OpLoopMerge instruction, which contradicts this block being
// a selection header.
// - The kContinue is in a multi-block loop that has a non-empty
// kLoop; and the selection contains the kContinue block but not the
// loop block. That breaks dominance rules. That is, the continue
// target is dominated by that loop header, and so gets found by the
// block traversal on the outside before the selection is found. The
// selection is inside the outer loop.
//
// So we fall into one of the following cases:
// - We are entering 0 or 1 constructs, or
// - We are entering 2 constructs, with the outer one being a kContinue or
// kLoop, the inner one is not a continue.
if (entering_constructs.Length() > 2) {
return Fail() << "internal error: bad construct nesting found";
}
if (entering_constructs.Length() == 2) {
auto inner_kind = entering_constructs[0]->kind;
auto outer_kind = entering_constructs[1]->kind;
if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Only a Continue "
"or a Loop construct can be outer construct on same block. "
"Got outer kind "
<< int(outer_kind) << " inner kind " << int(inner_kind);
}
if (inner_kind == Construct::kContinue) {
return Fail() << "internal error: unsupported construct nesting: "
"Continue around Continue";
}
if (inner_kind != Construct::kIfSelection && inner_kind != Construct::kSwitchSelection &&
inner_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Continue around "
"something other than if, switch, or loop";
}
}
// Enter constructs from outermost to innermost.
// kLoop and kContinue push a new statement-block onto the stack before
// emitting statements in the block.
// kIfSelection and kSwitchSelection emit statements in the block and then
// push a new statement-block. Only emit the statements in the block
// once.
// Have we emitted the statements for this block?
bool emitted = false;
// When entering an if-selection or switch-selection, we will emit the WGSL
// construct to cause the divergent branching. But otherwise, we will
// emit a "normal" block terminator, which occurs at the end of this method.
bool has_normal_terminator = true;
for (auto iter = entering_constructs.rbegin(); iter != entering_constructs.rend(); ++iter) {
const Construct* construct = *iter;
switch (construct->kind) {
case Construct::kFunction:
return Fail() << "internal error: nested function construct";
case Construct::