blob: 95dd5ffc687546ff39ebcb85976a419de1907938 [file] [log] [blame]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_RESOLVER_RESOLVER_H_
#define SRC_TINT_RESOLVER_RESOLVER_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/builtin_table.h"
#include "src/tint/program_builder.h"
#include "src/tint/resolver/dependency_graph.h"
#include "src/tint/resolver/sem_helper.h"
#include "src/tint/scope_stack.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/unique_vector.h"
// Forward declarations
namespace tint::ast {
// Expressions
class BinaryExpression;
class BitcastExpression;
class CallExpression;
class IdentifierExpression;
class IndexAccessorExpression;
class MemberAccessorExpression;
class UnaryOpExpression;
// Statements
class CallStatement;
class CaseStatement;
class ForLoopStatement;
class LoopStatement;
class ReturnStatement;
class SwitchStatement;
// Declarations
class Function;
class Variable;
} // namespace tint::ast
namespace tint::sem {
// Types
class Array;
class Atomic;
// Builtins and type constructors
class Builtin;
class TypeConstructor;
// Statements
class Statement;
class BlockStatement;
class CaseStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class LoopStatement;
class SwitchStatement;
} // namespace tint::sem
namespace tint::resolver {
/// Resolves types for all items in the given tint program, building the
/// semantic information for the program and validating it along the way.
class Resolver {
public:
/// Constructor
/// @param builder the program builder
explicit Resolver(ProgramBuilder* builder);
/// Destructor
~Resolver();
/// @returns error messages from the resolver
std::string error() const { return diagnostics_.str(); }
/// @returns true if the resolver was successful
bool Resolve();
/// @param type the given type
/// @returns true if the given type is a plain type
bool IsPlain(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is a fixed-footprint type
bool IsFixedFootprint(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is storable
bool IsStorable(const sem::Type* type) const;
/// @param type the given type
/// @returns true if the given type is host-shareable
bool IsHostShareable(const sem::Type* type) const;
private:
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
/// A set of (type, storage class) pairs.
using ValidTypeStorageLayouts =
std::set<std::pair<const sem::Type*, ast::StorageClass>>;
/// NOTE(review): presumably the (type, storage class) layouts already
/// accepted by ValidateStorageClassLayout(); confirm in resolver.cc.
ValidTypeStorageLayouts valid_type_storage_layouts_;
/// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
struct BlockInfo {
/// The kind of block described by a BlockInfo
enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
/// Constructor
/// @param block the AST block statement
/// @param type the kind of block
/// @param parent the enclosing block's info, or nullptr for an outermost
/// block
BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent);
/// Destructor
~BlockInfo();
/// Walks outwards from this block through the `parent` links.
/// @param pred a predicate invoked with each `BlockInfo*` in turn
/// @returns the first BlockInfo (starting with this one) for which `pred`
/// returns true, or nullptr if no block in the chain matches.
template <typename Pred>
BlockInfo* FindFirstParent(Pred&& pred) {
BlockInfo* curr = this;
while (curr && !pred(curr)) {
curr = curr->parent;
}
return curr;
}
/// @param ty the block type to search for
/// @returns the first BlockInfo (starting with this one) whose `type` equals
/// `ty`, or nullptr if there is none.
BlockInfo* FindFirstParent(BlockInfo::Type ty) {
return FindFirstParent(
[ty](auto* block_info) { return block_info->type == ty; });
}
/// The AST block statement
ast::BlockStatement const* const block;
/// The kind of block
const Type type;
/// The enclosing block's info, or nullptr
BlockInfo* const parent;
/// The variables declared in this block
std::vector<const ast::Variable*> decls;
// first_continue is set to the index of the first variable in decls
// declared after the first continue statement in a loop block, if any.
// kNoContinue is the sentinel meaning "no such variable".
constexpr static size_t kNoContinue = size_t(~0);
size_t first_continue = kNoContinue;
};
/// Pairs a type declaration's AST node with its semantic type.
struct TypeDeclInfo {
/// The AST type declaration
ast::TypeDecl const* const ast;
/// The semantic type
sem::Type* const sem;
};
/// Resolves the program without creating the final semantic nodes.
/// @returns true on success, false on error
bool ResolveInternal();
/// Creates the nodes and adds them to the sem::Info mappings of the
/// ProgramBuilder.
void CreateSemanticNodes() const;
/// Retrieves information for the requested import.
/// @param src the source of the import
/// @param path the import path
/// @param name the method name to get information on
/// @param params the parameters to the method call
/// @param id out parameter for the external call ID. Must not be a nullptr.
/// @returns the return type of `name` in `path` or nullptr on error.
sem::Type* GetImportData(const Source& src,
const std::string& path,
const std::string& name,
const ast::ExpressionList& params,
uint32_t* id);
//////////////////////////////////////////////////////////////////////////////
// AST and Type traversal methods
//////////////////////////////////////////////////////////////////////////////
// Expression resolving methods
// Each returns the semantic node pointer on success, nullptr on failure.
// NOTE(review): FunctionCall(), BuiltinCall() and TypeConstructor() take their
// `args` / `arg_tys` vectors by const value, so every call copies them.
// `const&` would avoid the copies; the definitions in resolver.cc would have
// to change as well.
sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*);
sem::Expression* Binary(const ast::BinaryExpression*);
sem::Expression* Bitcast(const ast::BitcastExpression*);
sem::Call* Call(const ast::CallExpression*);
sem::Expression* Expression(const ast::Expression*);
sem::Function* Function(const ast::Function*);
sem::Call* FunctionCall(const ast::CallExpression*,
sem::Function* target,
const std::vector<const sem::Expression*> args,
sem::Behaviors arg_behaviors);
sem::Expression* Identifier(const ast::IdentifierExpression*);
sem::Call* BuiltinCall(const ast::CallExpression*,
sem::BuiltinType,
const std::vector<const sem::Expression*> args,
const std::vector<const sem::Type*> arg_tys);
sem::Expression* Literal(const ast::LiteralExpression*);
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
sem::Call* TypeConversion(const ast::CallExpression* expr,
const sem::Type* ty,
const sem::Expression* arg,
const sem::Type* arg_ty);
sem::Call* TypeConstructor(const ast::CallExpression* expr,
const sem::Type* ty,
const std::vector<const sem::Expression*> args,
const std::vector<const sem::Type*> arg_tys);
sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
// Statement resolving methods
// The pointer-returning methods return the semantic node on success and
// nullptr on failure. Statements() returns true on success, false on failure.
sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
sem::Statement* CompoundAssignmentStatement(
const ast::CompoundAssignmentStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
sem::GlobalVariable* GlobalVariable(const ast::Variable*);
sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*);
sem::Statement* IncrementDecrementStatement(
const ast::IncrementDecrementStatement*);
sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*);
sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&);
// Resolve the result type of a binary operator.
// Returns nullptr if the types are not valid for this operator.
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op);
// AST and Type validation methods
// Each returns true on success, false on failure.
bool ValidatePipelineStages() const;
bool ValidateAlias(const ast::Alias*) const;
bool ValidateArray(const sem::Array* arr, const Source& source) const;
bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size,
uint32_t el_align,
const Source& source) const;
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const;
bool ValidateAtomicVariable(const sem::Variable* var) const;
bool ValidateAssignment(const ast::Statement* a,
const sem::Type* rhs_ty) const;
bool ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) const;
bool ValidateBreakStatement(const sem::Statement* stmt) const;
bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_type,
ast::PipelineStage stage,
const bool is_input) const;
bool ValidateContinueStatement(const sem::Statement* stmt) const;
bool ValidateDiscardStatement(const sem::Statement* stmt) const;
bool ValidateElseStatement(const sem::ElseStatement* stmt) const;
bool ValidateEntryPoint(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const;
bool ValidateFallthroughStatement(const sem::Statement* stmt) const;
bool ValidateFunction(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateFunctionCall(const sem::Call* call) const;
bool ValidateGlobalVariable(const sem::Variable* var) const;
bool ValidateIfStatement(const sem::IfStatement* stmt) const;
bool ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt) const;
bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type) const;
bool ValidateBuiltinCall(const sem::Call* call) const;
bool ValidateLocationAttribute(const ast::LocationAttribute* location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
bool ValidateLoopStatement(const sem::LoopStatement* stmt) const;
bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const;
bool ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var) const;
bool ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type,
const sem::Type* ret_type) const;
bool ValidateStatements(const ast::StatementList& stmts) const;
bool ValidateStorageTexture(const ast::StorageTexture* t) const;
bool ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) const;
bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type) const;
// NOTE(review): unlike the other Validate* methods this one is not const;
// confirm whether it intentionally mutates resolver state.
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var) const;
bool ValidateVariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const sem::Type* rhs_type) const;
bool ValidateVector(const sem::Vector* ty, const Source& source) const;
bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type) const;
bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type) const;
bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type) const;
bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type) const;
bool ValidateTextureBuiltinFunction(const sem::Call* call) const;
bool ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) const;
bool ValidateStorageClassLayout(const sem::Type* type,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layouts) const;
bool ValidateStorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layouts) const;
/// @param attributes the attribute list to search
/// @param validation the validation mode to look for
/// @returns true if the attribute list contains a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
bool IsValidationDisabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// @param attributes the attribute list to search
/// @param validation the validation mode to look for
/// @returns true if the attribute list does not contain a
/// ast::DisableValidationAttribute with the validation mode equal to
/// `validation`
bool IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
/// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_
bool WorkgroupSize(const ast::Function*);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
/// returned.
/// @param ty the ast::Type
sem::Type* Type(const ast::Type* ty);
/// @param named_type the named type to resolve
/// @returns the resolved semantic type
sem::Type* TypeDecl(const ast::TypeDecl* named_type);
/// Builds and returns the semantic information for the array `arr`.
/// This method does not mark the ast::Array node, nor attach the generated
/// semantic information to the AST node.
/// @returns the semantic Array information, or nullptr if an error is
/// raised.
/// @param arr the Array to get semantic information for
sem::Array* Array(const ast::Array* arr);
/// Builds and returns the semantic information for the alias `alias`.
/// This method does not mark the ast::Alias node, nor attach the generated
/// semantic information to the AST node.
/// @param alias the Alias to resolve
/// @returns the aliased type, or nullptr if an error is raised.
sem::Type* Alias(const ast::Alias* alias);
/// Builds and returns the semantic information for the structure `str`.
/// This method does not mark the ast::Struct node, nor attach the generated
/// semantic information to the AST node.
/// @param str the Struct to get semantic information for
/// @returns the semantic Struct information, or nullptr if an error is
/// raised.
sem::Struct* Structure(const ast::Struct* str);
/// @returns the semantic info for the variable `var`. If an error is
/// raised, nullptr is returned.
/// @note this method does not resolve the attributes as these are
/// context-dependent (global, local, parameter)
/// @param var the AST variable to create or return the semantic variable for
/// @param kind what kind of variable we are declaring
/// @param index the index of the parameter, if this variable is a parameter
sem::Variable* Variable(const ast::Variable* var,
VariableKind kind,
uint32_t index = 0);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
/// given storage class, erroring if it cannot.
/// @param sc the storage class to apply to the type and its dependent types
/// @param ty the type to apply the storage class on
/// @param usage the Source of the root variable declaration that uses the
/// given type and storage class. Used for generating sensible error
/// messages.
/// @returns true on success, false on error
bool ApplyStorageClassUsageToType(ast::StorageClass sc,
sem::Type* ty,
const Source& usage);
/// @param storage_class the storage class
/// @returns the default access control for the given storage class
ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
/// Allocate constant IDs for pipeline-overridable constants.
void AllocateOverridableConstantIds();
/// Set the shadowing information on variable declarations.
/// @note this method must only be called after all semantic nodes are built.
void SetShadows();
/// StatementScope() does the following:
/// * Creates the AST -> SEM mapping.
/// * Assigns `sem` to #current_statement_
/// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
/// sem::CompoundStatement.
/// * Assigns `sem` to #current_block_ if `sem` derives from
/// sem::BlockStatement.
/// * Then calls `callback`.
/// * Before returning #current_statement_, #current_compound_statement_, and
/// #current_block_ are restored to their original values.
/// @param ast the AST statement
/// @param sem the semantic statement for `ast`
/// @param callback the function to invoke while the scope is active
/// @returns `sem` if `callback` returns true, otherwise `nullptr`.
template <typename SEM, typename F>
SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Mark records that the given AST node has been visited, and asserts that
/// the given node has not already been seen. Diamonds in the AST are
/// illegal.
/// @param node the AST node.
/// @returns true on success, false on error
bool Mark(const ast::Node* node);
/// Adds the given error message to the diagnostics
/// @param msg the error message
/// @param source the source the message refers to
void AddError(const std::string& msg, const Source& source) const;
/// Adds the given warning message to the diagnostics
/// @param msg the warning message
/// @param source the source the message refers to
void AddWarning(const std::string& msg, const Source& source) const;
/// Adds the given note message to the diagnostics
/// @param msg the note message
/// @param source the source the message refers to
void AddNote(const std::string& msg, const Source& source) const;
//////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods
//////////////////////////////////////////////////////////////////////////////
/// Casts the constant `value` to the element type `target_elem_type`
/// @param value the constant value to cast
/// @param target_elem_type the element type to cast to
/// @returns the cast value
sem::Constant ConstantCast(const sem::Constant& value,
const sem::Type* target_elem_type);
/// Evaluates the constant value of an expression.
/// @param expr the expression to evaluate
/// @param type the type of the expression
/// @returns the evaluated constant
sem::Constant EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type);
/// Evaluates the constant value of a literal expression.
/// @param literal the literal to evaluate
/// @param type the type of the literal
/// @returns the evaluated constant
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type);
/// Evaluates the constant value of a call expression.
/// @param call the call to evaluate
/// @param type the type of the call's result
/// @returns the evaluated constant
sem::Constant EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* type);
/// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const;
/// @returns true if `expr` is the current CallStatement's CallExpression
bool IsCallStatement(const ast::Expression* expr) const;
/// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that
/// (transitively) owns the current statement.
/// @param stop_at_loop if true then the function will return nullptr if a
/// loop or for-loop was found before the continuing.
const ast::Statement* ClosestContinuing(bool stop_at_loop) const;
/// @param node the ast::Identifier or ast::TypeName to look up
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type, or
/// nullptr if `node` has no entry in the dependency graph's resolved
/// symbols.
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
: nullptr;
}
/// Identifies a type conversion by its target and source types.
/// Used as the key type of `type_conversions_`.
struct TypeConversionSig {
/// The type being converted to
const sem::Type* target;
/// The type being converted from
const sem::Type* source;
/// @returns true if this signature is equal to the given signature
bool operator==(const TypeConversionSig&) const;
/// Hasher provides a hash function for the TypeConversionSig
struct Hasher {
/// @param sig the TypeConversionSig to create a hash for
/// @return the hash value
std::size_t operator()(const TypeConversionSig& sig) const;
};
};
/// Identifies a type constructor by the constructed type and its parameter
/// types. Used as the key type of `type_ctors_`.
struct TypeConstructorSig {
/// The type being constructed
const sem::Type* type;
/// The types of the constructor's parameters
const std::vector<const sem::Type*> parameters;
/// Constructor
/// @param ty the type being constructed
/// @param params the constructor's parameter types
TypeConstructorSig(const sem::Type* ty,
const std::vector<const sem::Type*> params);
/// Copy constructor
TypeConstructorSig(const TypeConstructorSig&);
/// Destructor
~TypeConstructorSig();
/// @returns true if this signature is equal to the given signature
bool operator==(const TypeConstructorSig&) const;
/// Hasher provides a hash function for the TypeConstructorSig
struct Hasher {
/// @param sig the TypeConstructorSig to create a hash for
/// @return the hash value
std::size_t operator()(const TypeConstructorSig& sig) const;
};
};
// NOTE: the data members below are initialized in declaration order, so
// reordering them can change construction behavior.
/// The program builder being resolved
ProgramBuilder* const builder_;
/// The diagnostic list reported by error()
diag::List& diagnostics_;
/// The builtin table, owned by the resolver
std::unique_ptr<BuiltinTable> const builtin_table_;
/// The dependency graph; its `resolved_symbols` are queried by
/// ResolvedSymbol()
DependencyGraph dependencies_;
/// Semantic helper
SemHelper sem_;
/// Entry point functions (presumably collected during resolving)
std::vector<sem::Function*> entry_points_;
/// Maps a composite type to a Source (presumably for composites that contain
/// an atomic).
/// NOTE(review): the mapped type is a reference, so every referenced Source
/// must outlive this map. `const Source*` would make that lifetime
/// requirement explicit.
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
/// AST nodes already recorded by Mark()
std::unordered_set<const ast::Node*> marked_;
/// Constant ID to variable map (presumably populated by
/// AllocateOverridableConstantIds())
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
/// Call targets for type conversions, keyed by TypeConversionSig
std::unordered_map<TypeConversionSig,
sem::CallTarget*,
TypeConversionSig::Hasher>
type_conversions_;
/// Call targets for type constructors, keyed by TypeConstructorSig
std::unordered_map<TypeConstructorSig,
sem::CallTarget*,
TypeConstructorSig::Hasher>
type_ctors_;
/// The function currently being resolved
sem::Function* current_function_ = nullptr;
/// The statement currently being resolved (see StatementScope())
sem::Statement* current_statement_ = nullptr;
/// The innermost compound statement being resolved (see StatementScope())
sem::CompoundStatement* current_compound_statement_ = nullptr;
/// The innermost block statement being resolved (see StatementScope())
sem::BlockStatement* current_block_ = nullptr;
};
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_RESOLVER_H_