blob: 5972a829e5eb6c29da08002ddf44d7233b50be71 [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/hlsl.h"
#include <utility>
#include "src/ast/stage_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/decompose_storage_access.h"
#include "src/transform/manager.h"
namespace tint {
namespace transform {
namespace {
// This list is used for a binary search and must be kept in sorted order.
const char* kReservedKeywords[] = {"AddressU",
"AddressV",
"AddressW",
"AllMemoryBarrier",
"AllMemoryBarrierWithGroupSync",
"AppendStructuredBuffer",
"BINORMAL",
"BLENDINDICES",
"BLENDWEIGHT",
"BlendState",
"BorderColor",
"Buffer",
"ByteAddressBuffer",
"COLOR",
"CheckAccessFullyMapped",
"ComparisonFunc",
"CompileShader",
"ComputeShader",
"ConsumeStructuredBuffer",
"D3DCOLORtoUBYTE4",
"DEPTH",
"DepthStencilState",
"DepthStencilView",
"DeviceMemoryBarrier",
"DeviceMemroyBarrierWithGroupSync",
"DomainShader",
"EvaluateAttributeAtCentroid",
"EvaluateAttributeAtSample",
"EvaluateAttributeSnapped",
"FOG",
"Filter",
"GeometryShader",
"GetRenderTargetSampleCount",
"GetRenderTargetSamplePosition",
"GroupMemoryBarrier",
"GroupMemroyBarrierWithGroupSync",
"Hullshader",
"InputPatch",
"InterlockedAdd",
"InterlockedAnd",
"InterlockedCompareExchange",
"InterlockedCompareStore",
"InterlockedExchange",
"InterlockedMax",
"InterlockedMin",
"InterlockedOr",
"InterlockedXor",
"LineStream",
"MaxAnisotropy",
"MaxLOD",
"MinLOD",
"MipLODBias",
"NORMAL",
"NULL",
"Normal",
"OutputPatch",
"POSITION",
"POSITIONT",
"PSIZE",
"PixelShader",
"PointStream",
"Process2DQuadTessFactorsAvg",
"Process2DQuadTessFactorsMax",
"Process2DQuadTessFactorsMin",
"ProcessIsolineTessFactors",
"ProcessQuadTessFactorsAvg",
"ProcessQuadTessFactorsMax",
"ProcessQuadTessFactorsMin",
"ProcessTriTessFactorsAvg",
"ProcessTriTessFactorsMax",
"ProcessTriTessFactorsMin",
"RWBuffer",
"RWByteAddressBuffer",
"RWStructuredBuffer",
"RWTexture1D",
"RWTexture2D",
"RWTexture2DArray",
"RWTexture3D",
"RasterizerState",
"RenderTargetView",
"SV_ClipDistance",
"SV_Coverage",
"SV_CullDistance",
"SV_Depth",
"SV_DepthGreaterEqual",
"SV_DepthLessEqual",
"SV_DispatchThreadID",
"SV_DomainLocation",
"SV_GSInstanceID",
"SV_GroupID",
"SV_GroupIndex",
"SV_GroupThreadID",
"SV_InnerCoverage",
"SV_InsideTessFactor",
"SV_InstanceID",
"SV_IsFrontFace",
"SV_OutputControlPointID",
"SV_Position",
"SV_PrimitiveID",
"SV_RenderTargetArrayIndex",
"SV_SampleIndex",
"SV_StencilRef",
"SV_Target",
"SV_TessFactor",
"SV_VertexArrayIndex",
"SV_VertexID",
"Sampler",
"Sampler1D",
"Sampler2D",
"Sampler3D",
"SamplerCUBE",
"StructuredBuffer",
"TANGENT",
"TESSFACTOR",
"TEXCOORD",
"Texcoord",
"Texture",
"Texture1D",
"Texture2D",
"Texture2DArray",
"Texture2DMS",
"Texture2DMSArray",
"Texture3D",
"TextureCube",
"TextureCubeArray",
"TriangleStream",
"VFACE",
"VPOS",
"VertexShader",
"abort",
"allow_uav_condition",
"asdouble",
"asfloat",
"asint",
"asm",
"asm_fragment",
"asuint",
"auto",
"bool",
"bool1",
"bool1x1",
"bool1x2",
"bool1x3",
"bool1x4",
"bool2",
"bool2x1",
"bool2x2",
"bool2x3",
"bool2x4",
"bool3",
"bool3x1",
"bool3x2",
"bool3x3",
"bool3x4",
"bool4",
"bool4x1",
"bool4x2",
"bool4x3",
"bool4x4",
"branch",
"break",
"call",
"case",
"catch",
"cbuffer",
"centroid",
"char",
"class",
"clip",
"column_major",
"compile_fragment",
"const",
"const_cast",
"continue",
"countbits",
"ddx",
"ddx_coarse",
"ddx_fine",
"ddy",
"ddy_coarse",
"ddy_fine",
"degrees",
"delete",
"discard",
"do",
"double",
"double1",
"double1x1",
"double1x2",
"double1x3",
"double1x4",
"double2",
"double2x1",
"double2x2",
"double2x3",
"double2x4",
"double3",
"double3x1",
"double3x2",
"double3x3",
"double3x4",
"double4",
"double4x1",
"double4x2",
"double4x3",
"double4x4",
"dst",
"dword",
"dword1",
"dword1x1",
"dword1x2",
"dword1x3",
"dword1x4",
"dword2",
"dword2x1",
"dword2x2",
"dword2x3",
"dword2x4",
"dword3",
"dword3x1",
"dword3x2",
"dword3x3",
"dword3x4",
"dword4",
"dword4x1",
"dword4x2",
"dword4x3",
"dword4x4",
"dynamic_cast",
"else",
"enum",
"errorf",
"explicit",
"export",
"extern",
"f16to32",
"f32tof16",
"false",
"fastopt",
"firstbithigh",
"firstbitlow",
"flatten",
"float",
"float1",
"float1x1",
"float1x2",
"float1x3",
"float1x4",
"float2",
"float2x1",
"float2x2",
"float2x3",
"float2x4",
"float3",
"float3x1",
"float3x2",
"float3x3",
"float3x4",
"float4",
"float4x1",
"float4x2",
"float4x3",
"float4x4",
"fmod",
"for",
"forcecase",
"frac",
"friend",
"fxgroup",
"goto",
"groupshared",
"half",
"half1",
"half1x1",
"half1x2",
"half1x3",
"half1x4",
"half2",
"half2x1",
"half2x2",
"half2x3",
"half2x4",
"half3",
"half3x1",
"half3x2",
"half3x3",
"half3x4",
"half4",
"half4x1",
"half4x2",
"half4x3",
"half4x4",
"if",
"in",
"inline",
"inout",
"int",
"int1",
"int1x1",
"int1x2",
"int1x3",
"int1x4",
"int2",
"int2x1",
"int2x2",
"int2x3",
"int2x4",
"int3",
"int3x1",
"int3x2",
"int3x3",
"int3x4",
"int4",
"int4x1",
"int4x2",
"int4x3",
"int4x4",
"interface",
"isfinite",
"isinf",
"isnan",
"lerp",
"lineadj",
"linear",
"lit",
"log10",
"long",
"loop",
"mad",
"matrix",
"min10float",
"min10float1",
"min10float1x1",
"min10float1x2",
"min10float1x3",
"min10float1x4",
"min10float2",
"min10float2x1",
"min10float2x2",
"min10float2x3",
"min10float2x4",
"min10float3",
"min10float3x1",
"min10float3x2",
"min10float3x3",
"min10float3x4",
"min10float4",
"min10float4x1",
"min10float4x2",
"min10float4x3",
"min10float4x4",
"min12int",
"min12int1",
"min12int1x1",
"min12int1x2",
"min12int1x3",
"min12int1x4",
"min12int2",
"min12int2x1",
"min12int2x2",
"min12int2x3",
"min12int2x4",
"min12int3",
"min12int3x1",
"min12int3x2",
"min12int3x3",
"min12int3x4",
"min12int4",
"min12int4x1",
"min12int4x2",
"min12int4x3",
"min12int4x4",
"min16float",
"min16float1",
"min16float1x1",
"min16float1x2",
"min16float1x3",
"min16float1x4",
"min16float2",
"min16float2x1",
"min16float2x2",
"min16float2x3",
"min16float2x4",
"min16float3",
"min16float3x1",
"min16float3x2",
"min16float3x3",
"min16float3x4",
"min16float4",
"min16float4x1",
"min16float4x2",
"min16float4x3",
"min16float4x4",
"min16int",
"min16int1",
"min16int1x1",
"min16int1x2",
"min16int1x3",
"min16int1x4",
"min16int2",
"min16int2x1",
"min16int2x2",
"min16int2x3",
"min16int2x4",
"min16int3",
"min16int3x1",
"min16int3x2",
"min16int3x3",
"min16int3x4",
"min16int4",
"min16int4x1",
"min16int4x2",
"min16int4x3",
"min16int4x4",
"min16uint",
"min16uint1",
"min16uint1x1",
"min16uint1x2",
"min16uint1x3",
"min16uint1x4",
"min16uint2",
"min16uint2x1",
"min16uint2x2",
"min16uint2x3",
"min16uint2x4",
"min16uint3",
"min16uint3x1",
"min16uint3x2",
"min16uint3x3",
"min16uint3x4",
"min16uint4",
"min16uint4x1",
"min16uint4x2",
"min16uint4x3",
"min16uint4x4",
"msad4",
"mul",
"mutable",
"namespace",
"new",
"nointerpolation",
"noise",
"noperspective",
"numthreads",
"operator",
"out",
"packoffset",
"pass",
"pixelfragment",
"pixelshader",
"point",
"precise",
"printf",
"private",
"protected",
"public",
"radians",
"rcp",
"refract",
"register",
"reinterpret_cast",
"return",
"row_major",
"rsqrt",
"sample",
"sampler",
"sampler1D",
"sampler2D",
"sampler3D",
"samplerCUBE",
"sampler_state",
"saturate",
"shared",
"short",
"signed",
"sincos",
"sizeof",
"snorm",
"stateblock",
"stateblock_state",
"static",
"static_cast",
"string",
"struct",
"switch",
"tbuffer",
"technique",
"technique10",
"technique11",
"template",
"tex1D",
"tex1Dbias",
"tex1Dgrad",
"tex1Dlod",
"tex1Dproj",
"tex2D",
"tex2Dbias",
"tex2Dgrad",
"tex2Dlod",
"tex2Dproj",
"tex3D",
"tex3Dbias",
"tex3Dgrad",
"tex3Dlod",
"tex3Dproj",
"texCUBE",
"texCUBEbias",
"texCUBEgrad",
"texCUBElod",
"texCUBEproj",
"texture",
"texture1D",
"texture1DArray",
"texture2D",
"texture2DArray",
"texture2DMS",
"texture2DMSArray",
"texture3D",
"textureCube",
"textureCubeArray",
"this",
"throw",
"transpose",
"triangle",
"triangleadj",
"true",
"try",
"typedef",
"typename",
"uint",
"uint1",
"uint1x1",
"uint1x2",
"uint1x3",
"uint1x4",
"uint2",
"uint2x1",
"uint2x2",
"uint2x3",
"uint2x4",
"uint3",
"uint3x1",
"uint3x2",
"uint3x3",
"uint3x4",
"uint4",
"uint4x1",
"uint4x2",
"uint4x3",
"uint4x4",
"uniform",
"union",
"unorm",
"unroll",
"unsigned",
"using",
"vector",
"vertexfragment",
"vertexshader",
"virtual",
"void",
"volatile",
"while"};
} // namespace
Hlsl::Hlsl() = default;
Hlsl::~Hlsl() = default;
Transform::Output Hlsl::Run(const Program* in, const DataMap& data) {
Manager manager;
manager.Add<CanonicalizeEntryPointIO>();
manager.Add<DecomposeStorageAccess>();
manager.Add<CalculateArrayLength>();
auto out = manager.Run(in, data);
if (!out.program.IsValid()) {
return out;
}
ProgramBuilder builder;
CloneContext ctx(&builder, &out.program);
PromoteInitializersToConstVar(ctx);
AddEmptyEntryPoint(ctx);
RenameReservedKeywords(&ctx, kReservedKeywords);
ctx.Clone();
return Output{Program(std::move(builder))};
}
void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
// Scan the AST nodes for array and structure initializers which
// need to be promoted to their own constant declaration.
// Note: Correct handling of nested expressions is guaranteed due to the
// depth-first traversal of the ast::Node::Clone() methods:
//
// The inner-most initializers are traversed first, and they are hoisted
// to const variables declared just above the statement of use. The outer
// initializer will then be hoisted, inserting themselves between the
// inner declaration and the statement of use. This pattern applies correctly
// to any nested depth.
//
// Depth-first traversal of the AST is guaranteed because AST nodes are fully
// immutable and require their children to be constructed first so their
// pointer can be passed to the parent's constructor.
for (auto* src_node : ctx.src->ASTNodes().Objects()) {
if (auto* src_init = src_node->As<ast::TypeConstructorExpression>()) {
auto* src_sem_expr = ctx.src->Sem().Get(src_init);
if (!src_sem_expr) {
TINT_ICE(ctx.dst->Diagnostics())
<< "ast::TypeConstructorExpression has no semantic expression node";
continue;
}
auto* src_sem_stmt = src_sem_expr->Stmt();
if (!src_sem_stmt) {
// Expression is outside of a statement. This usually means the
// expression is part of a global (module-scope) constant declaration.
// These must be constexpr, and so cannot contain the type of
// expressions that must be sanitized.
continue;
}
auto* src_stmt = src_sem_stmt->Declaration();
if (auto* src_var_decl = src_stmt->As<ast::VariableDeclStatement>()) {
if (src_var_decl->variable()->constructor() == src_init) {
// This statement is just a variable declaration with the initializer
// as the constructor value. This is what we're attempting to
// transform to, and so ignore.
continue;
}
}
auto* src_ty = src_sem_expr->Type();
if (src_ty->IsAnyOf<type::Array, type::Struct>()) {
// Create a new symbol for the constant
auto dst_symbol = ctx.dst->Symbols().New();
// Clone the type
auto* dst_ty = ctx.Clone(src_ty);
// Clone the initializer
auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the hoisted initializer
auto* dst_var = ctx.dst->Const(dst_symbol, dst_ty, dst_init);
// Construct the variable declaration statement
auto* dst_var_decl =
ctx.dst->create<ast::VariableDeclStatement>(dst_var);
// Construct the identifier for referencing the constant
auto* dst_ident = ctx.dst->Expr(dst_symbol);
// Insert the constant before the usage
ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
dst_var_decl);
// Replace the inlined initializer with a reference to the constant
ctx.Replace(src_init, dst_ident);
}
}
}
}
void Hlsl::AddEmptyEntryPoint(CloneContext& ctx) const {
for (auto* func : ctx.src->AST().Functions()) {
if (func->IsEntryPoint()) {
return;
}
}
ctx.dst->Func(
ctx.dst->Symbols().New("tint_unused_entry_point"), {},
ctx.dst->ty.void_(), {},
{ctx.dst->create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
}
} // namespace transform
} // namespace tint