[tint] Move validation code into a Validator class.
This CL moves the Validate methods from the Resolver into a specific
Validator class used by the Resolver.
Bug: tint:1313
Change-Id: Ida21a0cc65f2679739c8499de7065ff8b58c4efc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87150
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 980b16d..d7b2be3 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -376,9 +376,10 @@
"resolver/resolver.cc",
"resolver/resolver.h",
"resolver/resolver_constants.cc",
- "resolver/resolver_validation.cc",
"resolver/sem_helper.cc",
"resolver/sem_helper.h",
+ "resolver/validator.cc",
+ "resolver/validator.h",
"scope_stack.h",
"sem/array.h",
"sem/atomic_type.h",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index ae3f224..e2bb332 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -256,10 +256,11 @@
resolver/dependency_graph.h
resolver/resolver.cc
resolver/resolver_constants.cc
- resolver/resolver_validation.cc
resolver/resolver.h
resolver/sem_helper.cc
resolver/sem_helper.h
+ resolver/validator.cc
+ resolver/validator.h
scope_stack.h
sem/array.cc
sem/array.h
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 4f7c08d..5069e5c 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -85,7 +85,8 @@
: builder_(builder),
diagnostics_(builder->Diagnostics()),
builtin_table_(BuiltinTable::Create(*builder)),
- sem_(builder) {}
+ sem_(builder, dependencies_),
+ validator_(builder, sem_) {}
Resolver::~Resolver() = default;
@@ -138,7 +139,7 @@
SetShadows();
- if (!ValidatePipelineStages()) {
+ if (!validator_.PipelineStages(entry_points_)) {
return false;
}
@@ -172,7 +173,7 @@
}
if (auto* el = Type(t->type)) {
if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
- if (ValidateVector(vector, t->source)) {
+ if (validator_.Vector(vector, t->source)) {
return vector;
}
}
@@ -188,7 +189,7 @@
if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
if (auto* matrix =
builder_->create<sem::Matrix>(column_type, t->columns)) {
- if (ValidateMatrix(matrix, t->source)) {
+ if (validator_.Matrix(matrix, t->source)) {
return matrix;
}
}
@@ -200,7 +201,7 @@
[&](const ast::Atomic* t) -> sem::Atomic* {
if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(el);
- if (!ValidateAtomic(t, a)) {
+ if (!validator_.Atomic(t, a)) {
return nullptr;
}
return a;
@@ -240,7 +241,7 @@
},
[&](const ast::StorageTexture* t) -> sem::StorageTexture* {
if (auto* el = Type(t->type)) {
- if (!ValidateStorageTexture(t)) {
+ if (!validator_.StorageTexture(t)) {
return nullptr;
}
return builder_->create<sem::StorageTexture>(t->dim, t->format,
@@ -252,7 +253,7 @@
return builder_->create<sem::ExternalTexture>();
},
[&](Default) {
- auto* resolved = ResolvedSymbol(ty);
+ auto* resolved = sem_.ResolvedSymbol(ty);
return Switch(
resolved, //
[&](sem::Type* type) { return type; },
@@ -366,8 +367,8 @@
if (kind == VariableKind::kLocal && !var->is_const &&
storage_class != ast::StorageClass::kFunction &&
- IsValidationEnabled(var->attributes,
- ast::DisabledValidation::kIgnoreStorageClass)) {
+ validator_.IsValidationEnabled(
+ var->attributes, ast::DisabledValidation::kIgnoreStorageClass)) {
AddError("function variable has a non-function storage class", var->source);
return nullptr;
}
@@ -385,8 +386,8 @@
builder_->create<sem::Reference>(storage_ty, storage_class, access);
}
- if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
- rhs->Type())) {
+ if (rhs && !validator_.VariableConstructorOrCast(var, storage_class,
+ storage_ty, rhs->Type())) {
return nullptr;
}
@@ -547,17 +548,17 @@
}
}
- if (!ValidateNoDuplicateAttributes(var->attributes)) {
+ if (!validator_.NoDuplicateAttributes(var->attributes)) {
return nullptr;
}
- if (!ValidateGlobalVariable(sem)) {
+ if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) {
return nullptr;
}
// TODO(bclayton): Call this at the end of resolve on all uniform and storage
// referenced structs
- if (!ValidateStorageClassLayout(sem, valid_type_storage_layouts_)) {
+ if (!validator_.StorageClassLayout(sem, valid_type_storage_layouts_)) {
return nullptr;
}
@@ -592,7 +593,7 @@
for (auto* attr : param->attributes) {
Mark(attr);
}
- if (!ValidateNoDuplicateAttributes(param->attributes)) {
+ if (!validator_.NoDuplicateAttributes(param->attributes)) {
return nullptr;
}
@@ -691,21 +692,21 @@
for (auto* attr : decl->attributes) {
Mark(attr);
}
- if (!ValidateNoDuplicateAttributes(decl->attributes)) {
+ if (!validator_.NoDuplicateAttributes(decl->attributes)) {
return nullptr;
}
for (auto* attr : decl->return_type_attributes) {
Mark(attr);
}
- if (!ValidateNoDuplicateAttributes(decl->return_type_attributes)) {
+ if (!validator_.NoDuplicateAttributes(decl->return_type_attributes)) {
return nullptr;
}
auto stage = current_function_
? current_function_->Declaration()->PipelineStage()
: ast::PipelineStage::kNone;
- if (!ValidateFunction(func, stage)) {
+ if (!validator_.Function(func, stage)) {
return nullptr;
}
@@ -809,7 +810,7 @@
<< "could not resolve constant workgroup_size constant value";
continue;
}
- // Validate and set the default value for this dimension.
+ // validator_.Validate and set the default value for this dimension.
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source);
return false;
@@ -843,7 +844,7 @@
current_statement_->Behaviors() = behaviors;
- if (!ValidateStatements(stmts)) {
+ if (!validator_.Statements(stmts)) {
return false;
}
@@ -958,7 +959,7 @@
sem->Behaviors().Add(sem::Behavior::kNext);
}
- return ValidateIfStatement(sem);
+ return validator_.IfStatement(sem);
});
}
@@ -989,7 +990,7 @@
}
sem->Behaviors().Add(body->Behaviors());
- return ValidateElseStatement(sem);
+ return validator_.ElseStatement(sem);
});
}
@@ -1039,7 +1040,7 @@
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
- return ValidateLoopStatement(sem);
+ return validator_.LoopStatement(sem);
});
});
}
@@ -1095,7 +1096,7 @@
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
- return ValidateForLoopStatement(sem);
+ return validator_.ForLoopStatement(sem);
});
}
@@ -1226,7 +1227,7 @@
sem->Behaviors() = inner->Behaviors();
- if (!ValidateBitcast(expr, ty)) {
+ if (!validator_.Bitcast(expr, ty)) {
return nullptr;
}
@@ -1316,7 +1317,7 @@
Mark(vec);
auto* v = builder_->create<sem::Vector>(
arg_el_ty, static_cast<uint32_t>(vec->width));
- if (!ValidateVector(v, vec->source)) {
+ if (!validator_.Vector(v, vec->source)) {
return nullptr;
}
builder_->Sem().Add(vec, v);
@@ -1337,7 +1338,7 @@
auto* column_type =
builder_->create<sem::Vector>(arg_el_ty, mat->rows);
auto* m = builder_->create<sem::Matrix>(column_type, mat->columns);
- if (!ValidateMatrix(m, mat->source)) {
+ if (!validator_.Matrix(m, mat->source)) {
return nullptr;
}
builder_->Sem().Add(mat, m);
@@ -1359,7 +1360,7 @@
auto* ident = expr->target.name;
Mark(ident);
- auto* resolved = ResolvedSymbol(ident);
+ auto* resolved = sem_.ResolvedSymbol(ident);
return Switch(
resolved, //
[&](sem::Type* type) { return type_ctor_or_conv(type); },
@@ -1414,7 +1415,7 @@
current_function_->AddDirectlyCalledBuiltin(builtin);
if (IsTextureBuiltin(builtin_type)) {
- if (!ValidateTextureBuiltinFunction(call)) {
+ if (!validator_.TextureBuiltinFunction(call)) {
return nullptr;
}
// Collect a texture/sampler pair for this builtin.
@@ -1436,7 +1437,7 @@
}
}
- if (!ValidateBuiltinCall(call)) {
+ if (!validator_.BuiltinCall(call)) {
return nullptr;
}
@@ -1500,7 +1501,7 @@
call->Behaviors() = arg_behaviors + target->Behaviors();
- if (!ValidateFunctionCall(call)) {
+ if (!validator_.FunctionCall(call, current_statement_)) {
return nullptr;
}
@@ -1527,23 +1528,23 @@
bool ok = Switch(
target,
[&](const sem::Vector* vec_type) {
- return ValidateVectorConstructorOrCast(expr, vec_type);
+ return validator_.VectorConstructorOrCast(expr, vec_type);
},
[&](const sem::Matrix* mat_type) {
// Note: Matrix types currently cannot be converted (the element
// type must only be f32). We implement this for the day we
// support other matrix element types.
- return ValidateMatrixConstructorOrCast(expr, mat_type);
+ return validator_.MatrixConstructorOrCast(expr, mat_type);
},
[&](const sem::Array* arr_type) {
- return ValidateArrayConstructorOrCast(expr, arr_type);
+ return validator_.ArrayConstructorOrCast(expr, arr_type);
},
[&](const sem::Struct* struct_type) {
- return ValidateStructureConstructorOrCast(expr, struct_type);
+ return validator_.StructureConstructorOrCast(expr, struct_type);
},
[&](Default) {
if (target->is_scalar()) {
- return ValidateScalarConstructorOrCast(expr, target);
+ return validator_.ScalarConstructorOrCast(expr, target);
}
AddError("type is not constructible", expr->source);
return false;
@@ -1593,20 +1594,20 @@
bool ok = Switch(
ty,
[&](const sem::Vector* vec_type) {
- return ValidateVectorConstructorOrCast(expr, vec_type);
+ return validator_.VectorConstructorOrCast(expr, vec_type);
},
[&](const sem::Matrix* mat_type) {
- return ValidateMatrixConstructorOrCast(expr, mat_type);
+ return validator_.MatrixConstructorOrCast(expr, mat_type);
},
[&](const sem::Array* arr_type) {
- return ValidateArrayConstructorOrCast(expr, arr_type);
+ return validator_.ArrayConstructorOrCast(expr, arr_type);
},
[&](const sem::Struct* struct_type) {
- return ValidateStructureConstructorOrCast(expr, struct_type);
+ return validator_.StructureConstructorOrCast(expr, struct_type);
},
[&](Default) {
if (ty->is_scalar()) {
- return ValidateScalarConstructorOrCast(expr, ty);
+ return validator_.ScalarConstructorOrCast(expr, ty);
}
AddError("type is not constructible", expr->source);
return false;
@@ -1652,7 +1653,7 @@
sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol;
- auto* resolved = ResolvedSymbol(expr);
+ auto* resolved = sem_.ResolvedSymbol(expr);
if (auto* var = As<sem::Variable>(resolved)) {
auto* user =
builder_->create<sem::VariableUser>(expr, current_statement_, var);
@@ -2156,7 +2157,8 @@
return nullptr;
}
- if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize()
+ if (!validator_.IsPlain(
+ elem_type)) { // Check must come before GetDefaultAlignAndSize()
AddError(sem_.TypeNameOf(elem_type) +
" cannot be used as an element type of an array",
source);
@@ -2166,7 +2168,7 @@
uint32_t el_align = elem_type->Align();
uint32_t el_size = elem_type->Size();
- if (!ValidateNoDuplicateAttributes(arr->attributes)) {
+ if (!validator_.NoDuplicateAttributes(arr->attributes)) {
return nullptr;
}
@@ -2176,7 +2178,7 @@
Mark(attr);
if (auto* sd = attr->As<ast::StrideAttribute>()) {
explicit_stride = sd->stride;
- if (!ValidateArrayStrideAttribute(sd, el_size, el_align, source)) {
+ if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) {
return nullptr;
}
continue;
@@ -2210,7 +2212,7 @@
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
- auto* var = ResolvedSymbol<sem::GlobalVariable>(ident);
+ auto* var = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
if (!var || !var->Declaration()->is_const) {
AddError("array size identifier must be a module-scope constant",
size_source);
@@ -2266,7 +2268,7 @@
elem_type, count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
- if (!ValidateArray(out, source)) {
+ if (!validator_.Array(out, source)) {
return nullptr;
}
@@ -2287,14 +2289,14 @@
if (!ty) {
return nullptr;
}
- if (!ValidateAlias(alias)) {
+ if (!validator_.Alias(alias)) {
return nullptr;
}
return ty;
}
sem::Struct* Resolver::Structure(const ast::Struct* str) {
- if (!ValidateNoDuplicateAttributes(str->attributes)) {
+ if (!validator_.NoDuplicateAttributes(str->attributes)) {
return nullptr;
}
for (auto* attr : str->attributes) {
@@ -2335,8 +2337,8 @@
return nullptr;
}
- // Validate member type
- if (!IsPlain(type)) {
+ // validator_.Validate member type
+ if (!validator_.IsPlain(type)) {
AddError(sem_.TypeNameOf(type) +
" cannot be used as the type of a structure member",
member->source);
@@ -2347,7 +2349,7 @@
uint64_t align = type->Align();
uint64_t size = type->Size();
- if (!ValidateNoDuplicateAttributes(member->attributes)) {
+ if (!validator_.NoDuplicateAttributes(member->attributes)) {
return nullptr;
}
@@ -2453,7 +2455,7 @@
auto stage = current_function_
? current_function_->Declaration()->PipelineStage()
: ast::PipelineStage::kNone;
- if (!ValidateStructure(out, stage)) {
+ if (!validator_.Structure(out, stage)) {
return nullptr;
}
@@ -2479,7 +2481,8 @@
// is available for validation.
auto* ret_type = stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef()
: builder_->create<sem::Void>();
- return ValidateReturn(stmt, current_function_->ReturnType(), ret_type);
+ return validator_.Return(stmt, current_function_->ReturnType(), ret_type,
+ current_statement_);
});
}
@@ -2510,7 +2513,7 @@
}
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
- return ValidateSwitch(stmt);
+ return validator_.SwitchStatement(stmt);
});
}
@@ -2542,7 +2545,7 @@
sem->Behaviors() = ctor->Behaviors();
}
- return ValidateVariable(var);
+ return validator_.Variable(var);
});
}
@@ -2567,7 +2570,7 @@
behaviors.Add(lhs->Behaviors());
}
- return ValidateAssignment(stmt, sem_.TypeOf(stmt->rhs));
+ return validator_.Assignment(stmt, sem_.TypeOf(stmt->rhs));
});
}
@@ -2577,7 +2580,7 @@
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kBreak;
- return ValidateBreakStatement(sem);
+ return validator_.BreakStatement(sem, current_statement_);
});
}
@@ -2620,7 +2623,7 @@
stmt->source);
return false;
}
- return ValidateAssignment(stmt, ty);
+ return validator_.Assignment(stmt, ty);
});
}
@@ -2639,7 +2642,7 @@
}
}
- return ValidateContinueStatement(sem);
+ return validator_.ContinueStatement(sem, current_statement_);
});
}
@@ -2650,7 +2653,7 @@
sem->Behaviors() = sem::Behavior::kDiscard;
current_function_->SetHasDiscard();
- return ValidateDiscardStatement(sem);
+ return validator_.DiscardStatement(sem, current_statement_);
});
}
@@ -2661,7 +2664,7 @@
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kFallthrough;
- return ValidateFallthroughStatement(sem);
+ return validator_.FallthroughStatement(sem);
});
}
@@ -2676,7 +2679,7 @@
}
sem->Behaviors() = lhs->Behaviors();
- return ValidateIncrementDecrementStatement(stmt);
+ return validator_.IncrementDecrementStatement(stmt);
});
}
@@ -2718,7 +2721,7 @@
sc, const_cast<sem::Type*>(arr->ElemType()), usage);
}
- if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
+ if (ast::IsHostShareable(sc) && !validator_.IsHostShareable(ty)) {
std::stringstream err;
err << "Type '" << sem_.TypeNameOf(ty)
<< "' cannot be used in storage class '" << sc
@@ -2782,62 +2785,6 @@
diagnostics_.add_note(diag::System::Resolver, msg, source);
}
-// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
-bool Resolver::IsPlain(const sem::Type* type) const {
- return type->is_scalar() ||
- type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
- sem::Struct>();
-}
-
-// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
-bool Resolver::IsFixedFootprint(const sem::Type* type) const {
- return Switch(
- type, //
- [&](const sem::Vector*) { return true; }, //
- [&](const sem::Matrix*) { return true; }, //
- [&](const sem::Atomic*) { return true; },
- [&](const sem::Array* arr) {
- return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
- },
- [&](const sem::Struct* str) {
- for (auto* member : str->Members()) {
- if (!IsFixedFootprint(member->Type())) {
- return false;
- }
- }
- return true;
- },
- [&](Default) { return type->is_scalar(); });
-}
-
-// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
-bool Resolver::IsStorable(const sem::Type* type) const {
- return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
-}
-
-// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
-bool Resolver::IsHostShareable(const sem::Type* type) const {
- if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
- return true;
- }
- return Switch(
- type, //
- [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
- [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
- [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
- [&](const sem::Struct* str) {
- for (auto* member : str->Members()) {
- if (!IsHostShareable(member->Type())) {
- return false;
- }
- }
- return true;
- },
- [&](const sem::Atomic* atomic) {
- return IsHostShareable(atomic->Type());
- });
-}
-
bool Resolver::IsBuiltin(Symbol symbol) const {
std::string name = builder_->Symbols().NameFor(symbol);
return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone;
@@ -2849,26 +2796,6 @@
[&](auto* stmt) { return stmt->expr == expr; });
}
-const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const {
- for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) {
- if (stop_at_loop && s->Is<sem::LoopStatement>()) {
- break;
- }
- if (s->Is<sem::LoopContinuingBlockStatement>()) {
- return s->Declaration();
- }
- if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
- if (f->Declaration()->continuing == s->Declaration()) {
- return s->Declaration();
- }
- if (stop_at_loop) {
- break;
- }
- }
- }
- return nullptr;
-}
-
////////////////////////////////////////////////////////////////////////////////
// Resolver::TypeConversionSig
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 95dd5ff..6487d35 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -16,7 +16,6 @@
#define SRC_TINT_RESOLVER_RESOLVER_H_
#include <memory>
-#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -27,13 +26,13 @@
#include "src/tint/program_builder.h"
#include "src/tint/resolver/dependency_graph.h"
#include "src/tint/resolver/sem_helper.h"
+#include "src/tint/resolver/validator.h"
#include "src/tint/scope_stack.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
-#include "src/tint/utils/map.h"
#include "src/tint/utils/unique_vector.h"
// Forward declarations
@@ -89,27 +88,31 @@
/// @param type the given type
/// @returns true if the given type is a plain type
- bool IsPlain(const sem::Type* type) const;
+ bool IsPlain(const sem::Type* type) const { return validator_.IsPlain(type); }
/// @param type the given type
/// @returns true if the given type is a fixed-footprint type
- bool IsFixedFootprint(const sem::Type* type) const;
+ bool IsFixedFootprint(const sem::Type* type) const {
+ return validator_.IsFixedFootprint(type);
+ }
/// @param type the given type
/// @returns true if the given type is storable
- bool IsStorable(const sem::Type* type) const;
+ bool IsStorable(const sem::Type* type) const {
+ return validator_.IsStorable(type);
+ }
/// @param type the given type
/// @returns true if the given type is host-shareable
- bool IsHostShareable(const sem::Type* type) const;
+ bool IsHostShareable(const sem::Type* type) const {
+ return validator_.IsHostShareable(type);
+ }
private:
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
- using ValidTypeStorageLayouts =
- std::set<std::pair<const sem::Type*, ast::StorageClass>>;
- ValidTypeStorageLayouts valid_type_storage_layouts_;
+ Validator::ValidTypeStorageLayouts valid_type_storage_layouts_;
/// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block.
@@ -237,106 +240,6 @@
const sem::Type* rhs_ty,
ast::BinaryOp op);
- // AST and Type validation methods
- // Each return true on success, false on failure.
- bool ValidatePipelineStages() const;
- bool ValidateAlias(const ast::Alias*) const;
- bool ValidateArray(const sem::Array* arr, const Source& source) const;
- bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
- uint32_t el_size,
- uint32_t el_align,
- const Source& source) const;
- bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const;
- bool ValidateAtomicVariable(const sem::Variable* var) const;
- bool ValidateAssignment(const ast::Statement* a,
- const sem::Type* rhs_ty) const;
- bool ValidateBitcast(const ast::BitcastExpression* cast,
- const sem::Type* to) const;
- bool ValidateBreakStatement(const sem::Statement* stmt) const;
- bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
- const sem::Type* storage_type,
- ast::PipelineStage stage,
- const bool is_input) const;
- bool ValidateContinueStatement(const sem::Statement* stmt) const;
- bool ValidateDiscardStatement(const sem::Statement* stmt) const;
- bool ValidateElseStatement(const sem::ElseStatement* stmt) const;
- bool ValidateEntryPoint(const sem::Function* func,
- ast::PipelineStage stage) const;
- bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const;
- bool ValidateFallthroughStatement(const sem::Statement* stmt) const;
- bool ValidateFunction(const sem::Function* func,
- ast::PipelineStage stage) const;
- bool ValidateFunctionCall(const sem::Call* call) const;
- bool ValidateGlobalVariable(const sem::Variable* var) const;
- bool ValidateIfStatement(const sem::IfStatement* stmt) const;
- bool ValidateIncrementDecrementStatement(
- const ast::IncrementDecrementStatement* stmt) const;
- bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
- const sem::Type* storage_type) const;
- bool ValidateBuiltinCall(const sem::Call* call) const;
- bool ValidateLocationAttribute(const ast::LocationAttribute* location,
- const sem::Type* type,
- std::unordered_set<uint32_t>& locations,
- ast::PipelineStage stage,
- const Source& source,
- const bool is_input = false) const;
- bool ValidateLoopStatement(const sem::LoopStatement* stmt) const;
- bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const;
- bool ValidateFunctionParameter(const ast::Function* func,
- const sem::Variable* var) const;
- bool ValidateReturn(const ast::ReturnStatement* ret,
- const sem::Type* func_type,
- const sem::Type* ret_type) const;
- bool ValidateStatements(const ast::StatementList& stmts) const;
- bool ValidateStorageTexture(const ast::StorageTexture* t) const;
- bool ValidateStructure(const sem::Struct* str,
- ast::PipelineStage stage) const;
- bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Struct* struct_type) const;
- bool ValidateSwitch(const ast::SwitchStatement* s);
- bool ValidateVariable(const sem::Variable* var) const;
- bool ValidateVariableConstructorOrCast(const ast::Variable* var,
- ast::StorageClass storage_class,
- const sem::Type* storage_type,
- const sem::Type* rhs_type) const;
- bool ValidateVector(const sem::Vector* ty, const Source& source) const;
- bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Vector* vec_type) const;
- bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Matrix* matrix_type) const;
- bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Type* type) const;
- bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Array* arr_type) const;
- bool ValidateTextureBuiltinFunction(const sem::Call* call) const;
- bool ValidateNoDuplicateAttributes(
- const ast::AttributeList& attributes) const;
- bool ValidateStorageClassLayout(const sem::Type* type,
- ast::StorageClass sc,
- Source source,
- ValidTypeStorageLayouts& layouts) const;
- bool ValidateStorageClassLayout(const sem::Variable* var,
- ValidTypeStorageLayouts& layouts) const;
-
- /// @returns true if the attribute list contains a
- /// ast::DisableValidationAttribute with the validation mode equal to
- /// `validation`
- bool IsValidationDisabled(const ast::AttributeList& attributes,
- ast::DisabledValidation validation) const;
-
- /// @returns true if the attribute list does not contains a
- /// ast::DisableValidationAttribute with the validation mode equal to
- /// `validation`
- bool IsValidationEnabled(const ast::AttributeList& attributes,
- ast::DisabledValidation validation) const;
-
- /// Returns a human-readable string representation of the vector type name
- /// with the given parameters.
- /// @param size the vector dimension
- /// @param element_type scalar vector sub-element type
- /// @return pretty string representation
- std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
-
/// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_
bool WorkgroupSize(const ast::Function*);
@@ -457,23 +360,6 @@
/// @returns true if `expr` is the current CallStatement's CallExpression
bool IsCallStatement(const ast::Expression* expr) const;
- /// Searches the current statement and up through parents of the current
- /// statement looking for a loop or for-loop continuing statement.
- /// @returns the closest continuing statement to the current statement that
- /// (transitively) owns the current statement.
- /// @param stop_at_loop if true then the function will return nullptr if a
- /// loop or for-loop was found before the continuing.
- const ast::Statement* ClosestContinuing(bool stop_at_loop) const;
-
- /// @returns the resolved symbol (function, type or variable) for the given
- /// ast::Identifier or ast::TypeName cast to the given semantic type.
- template <typename SEM = sem::Node>
- SEM* ResolvedSymbol(const ast::Node* node) const {
- auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
- return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
- : nullptr;
- }
-
struct TypeConversionSig {
const sem::Type* target;
const sem::Type* source;
@@ -511,6 +397,7 @@
std::unique_ptr<BuiltinTable> const builtin_table_;
DependencyGraph dependencies_;
SemHelper sem_;
+ Validator validator_;
std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
std::unordered_set<const ast::Node*> marked_;
diff --git a/src/tint/resolver/resolver_is_storeable_test.cc b/src/tint/resolver/resolver_is_storeable_test.cc
new file mode 100644
index 0000000..de180a3
--- /dev/null
+++ b/src/tint/resolver/resolver_is_storeable_test.cc
@@ -0,0 +1,79 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+
+namespace tint::resolver {
+namespace {
+
+using ResolverIsStorableTest = ResolverTest;
+
+TEST_F(ResolverIsStorableTest, Struct_AllMembersStorable) {
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIsStorableTest, Struct_SomeMembersNonStorable) {
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
+}
+
+TEST_F(ResolverIsStorableTest, Struct_NestedStorable) {
+ auto* storable = Structure("Storable", {
+ Member("a", ty.i32()),
+ Member("b", ty.f32()),
+ });
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.Of(storable)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverIsStorableTest, Struct_NestedNonStorable) {
+ auto* non_storable =
+ Structure("nonstorable",
+ {
+ Member("a", ty.i32()),
+ Member("b", ty.pointer<i32>(ast::StorageClass::kPrivate)),
+ });
+ Structure("S", {
+ Member("a", ty.i32()),
+ Member("b", ty.Of(non_storable)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: ptr<private, i32, read_write> cannot be used as the type of a structure member)");
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc
index 74b3c5b..57fff2d 100644
--- a/src/tint/resolver/sem_helper.cc
+++ b/src/tint/resolver/sem_helper.cc
@@ -18,7 +18,8 @@
namespace tint::resolver {
-SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {}
+SemHelper::SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies)
+ : builder_(builder), dependencies_(dependencies) {}
SemHelper::~SemHelper() = default;
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index 0e95397..58c2d57 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -19,6 +19,8 @@
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/program_builder.h"
+#include "src/tint/resolver/dependency_graph.h"
+#include "src/tint/utils/map.h"
namespace tint::resolver {
@@ -27,7 +29,8 @@
public:
/// Constructor
/// @param builder the program builder
- explicit SemHelper(ProgramBuilder* builder);
+ /// @param dependencies the program dependency graph
+ explicit SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies);
~SemHelper();
/// Get is a helper for obtaining the semantic node for the given AST node.
@@ -47,6 +50,16 @@
return const_cast<T*>(As<T>(sem));
}
+ /// @returns the resolved symbol (function, type or variable) for the given
+ /// ast::Identifier or ast::TypeName cast to the given semantic type.
+ /// @param node the node to retrieve
+ template <typename SEM = sem::Node>
+ SEM* ResolvedSymbol(const ast::Node* node) const {
+ auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
+ return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
+ : nullptr;
+ }
+
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr) const;
@@ -67,6 +80,7 @@
private:
ProgramBuilder* builder_;
+ DependencyGraph& dependencies_;
};
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/validator.cc
similarity index 85%
rename from src/tint/resolver/resolver_validation.cc
rename to src/tint/resolver/validator.cc
index 44e0a2f..ffe855e 100644
--- a/src/tint/resolver/resolver_validation.cc
+++ b/src/tint/resolver/validator.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/resolver/resolver.h"
+#include "src/tint/resolver/validator.h"
#include <algorithm>
#include <limits>
@@ -149,8 +149,104 @@
} // namespace
-bool Resolver::ValidateAtomic(const ast::Atomic* a,
- const sem::Atomic* s) const {
+Validator::Validator(ProgramBuilder* builder, SemHelper& sem)
+ : symbols_(builder->Symbols()),
+ diagnostics_(builder->Diagnostics()),
+ sem_(sem) {}
+
+Validator::~Validator() = default;
+
+void Validator::AddError(const std::string& msg, const Source& source) const {
+ diagnostics_.add_error(diag::System::Resolver, msg, source);
+}
+
+void Validator::AddWarning(const std::string& msg, const Source& source) const {
+ diagnostics_.add_warning(diag::System::Resolver, msg, source);
+}
+
+void Validator::AddNote(const std::string& msg, const Source& source) const {
+ diagnostics_.add_note(diag::System::Resolver, msg, source);
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
+bool Validator::IsPlain(const sem::Type* type) const {
+ return type->is_scalar() ||
+ type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
+ sem::Struct>();
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
+bool Validator::IsFixedFootprint(const sem::Type* type) const {
+ return Switch(
+ type, //
+ [&](const sem::Vector*) { return true; }, //
+ [&](const sem::Matrix*) { return true; }, //
+ [&](const sem::Atomic*) { return true; },
+ [&](const sem::Array* arr) {
+ return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
+ },
+ [&](const sem::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (!IsFixedFootprint(member->Type())) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](Default) { return type->is_scalar(); });
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
+bool Validator::IsHostShareable(const sem::Type* type) const {
+ if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
+ return true;
+ }
+ return Switch(
+ type, //
+ [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
+ [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
+ [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
+ [&](const sem::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (!IsHostShareable(member->Type())) {
+ return false;
+ }
+ }
+ return true;
+ },
+ [&](const sem::Atomic* atomic) {
+ return IsHostShareable(atomic->Type());
+ });
+}
+
+// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
+bool Validator::IsStorable(const sem::Type* type) const {
+ return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
+}
+
+const ast::Statement* Validator::ClosestContinuing(
+ bool stop_at_loop,
+ sem::Statement* current_statement) const {
+ for (const auto* s = current_statement; s != nullptr; s = s->Parent()) {
+ if (stop_at_loop && s->Is<sem::LoopStatement>()) {
+ break;
+ }
+ if (s->Is<sem::LoopContinuingBlockStatement>()) {
+ return s->Declaration();
+ }
+ if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
+ if (f->Declaration()->continuing == s->Declaration()) {
+ return s->Declaration();
+ }
+ if (stop_at_loop) {
+ break;
+ }
+ }
+ }
+ return nullptr;
+}
+
+bool Validator::Atomic(const ast::Atomic* a, const sem::Atomic* s) const {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// T must be either u32 or i32.
if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
@@ -161,7 +257,7 @@
return true;
}
-bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const {
+bool Validator::StorageTexture(const ast::StorageTexture* t) const {
switch (t->access) {
case ast::Access::kWrite:
break;
@@ -190,11 +286,10 @@
return true;
}
-bool Resolver::ValidateVariableConstructorOrCast(
- const ast::Variable* var,
- ast::StorageClass storage_class,
- const sem::Type* storage_ty,
- const sem::Type* rhs_ty) const {
+bool Validator::VariableConstructorOrCast(const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_ty,
+ const sem::Type* rhs_ty) const {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type
@@ -229,11 +324,10 @@
return true;
}
-bool Resolver::ValidateStorageClassLayout(
- const sem::Type* store_ty,
- ast::StorageClass sc,
- Source source,
- ValidTypeStorageLayouts& layouts) const {
+bool Validator::StorageClassLayout(const sem::Type* store_ty,
+ ast::StorageClass sc,
+ Source source,
+ ValidTypeStorageLayouts& layouts) const {
// https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
@@ -255,7 +349,7 @@
};
auto member_name_of = [this](const sem::StructMember* sm) {
- return builder_->Symbols().NameFor(sm->Declaration()->symbol);
+ return symbols_.NameFor(sm->Declaration()->symbol);
};
// Cache result of type + storage class pair.
@@ -273,9 +367,9 @@
uint32_t required_align = required_alignment_of(m->Type());
// Recurse into the member type.
- if (!ValidateStorageClassLayout(
- m->Type(), sc, m->Declaration()->type->source, layouts)) {
- AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ if (!StorageClassLayout(m->Type(), sc, m->Declaration()->type->source,
+ layouts)) {
+ AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source);
return false;
}
@@ -283,7 +377,7 @@
// Validate that member is at a valid byte offset
if (m->Offset() % required_align != 0) {
AddError("the offset of a struct member of type '" +
- m->Type()->UnwrapRef()->FriendlyName(builder_->Symbols()) +
+ m->Type()->UnwrapRef()->FriendlyName(symbols_) +
"' in storage class '" + ast::ToString(sc) +
"' must be a multiple of " +
std::to_string(required_align) + " bytes, but '" +
@@ -293,13 +387,13 @@
std::to_string(required_align) + ") on this member",
m->Declaration()->source);
- AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source);
if (auto* member_str = m->Type()->As<sem::Struct>()) {
- AddNote("and layout of struct member:\n" +
- member_str->Layout(builder_->Symbols()),
- member_str->Declaration()->source);
+ AddNote(
+ "and layout of struct member:\n" + member_str->Layout(symbols_),
+ member_str->Declaration()->source);
}
return false;
@@ -322,12 +416,12 @@
"'. Consider setting @align(16) on this member",
m->Declaration()->source);
- AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()),
+ AddNote("see layout of struct:\n" + str->Layout(symbols_),
str->Declaration()->source);
auto* prev_member_str = prev_member->Type()->As<sem::Struct>();
AddNote("and layout of previous member struct:\n" +
- prev_member_str->Layout(builder_->Symbols()),
+ prev_member_str->Layout(symbols_),
prev_member_str->Declaration()->source);
return false;
}
@@ -342,7 +436,7 @@
// TODO(crbug.com/tint/1388): Ideally we'd pass the source for nested
// element type here, but we can't easily get that from the semantic node.
// We should consider recursing through the AST type nodes instead.
- if (!ValidateStorageClassLayout(arr->ElemType(), sc, source, layouts)) {
+ if (!StorageClassLayout(arr->ElemType(), sc, source, layouts)) {
return false;
}
@@ -384,12 +478,11 @@
return true;
}
-bool Resolver::ValidateStorageClassLayout(
- const sem::Variable* var,
- ValidTypeStorageLayouts& layouts) const {
+bool Validator::StorageClassLayout(const sem::Variable* var,
+ ValidTypeStorageLayouts& layouts) const {
if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
- if (!ValidateStorageClassLayout(str, var->StorageClass(),
- str->Declaration()->source, layouts)) {
+ if (!StorageClassLayout(str, var->StorageClass(),
+ str->Declaration()->source, layouts)) {
AddNote("see declaration of variable", var->Declaration()->source);
return false;
}
@@ -398,8 +491,8 @@
if (var->Declaration()->type) {
source = var->Declaration()->type->source;
}
- if (!ValidateStorageClassLayout(var->Type()->UnwrapRef(),
- var->StorageClass(), source, layouts)) {
+ if (!StorageClassLayout(var->Type()->UnwrapRef(), var->StorageClass(),
+ source, layouts)) {
return false;
}
}
@@ -407,9 +500,13 @@
return true;
}
-bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const {
+bool Validator::GlobalVariable(
+ const sem::Variable* var,
+ std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
+ std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
+ const {
auto* decl = var->Declaration();
- if (!ValidateNoDuplicateAttributes(decl->attributes)) {
+ if (!NoDuplicateAttributes(decl->attributes)) {
return false;
}
@@ -417,8 +514,8 @@
if (decl->is_const) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) {
uint32_t id = id_attr->value;
- auto it = constant_ids_.find(id);
- if (it != constant_ids_.end() && it->second != var) {
+ auto it = constant_ids.find(id);
+ if (it != constant_ids.end() && it->second != var) {
AddError("pipeline constant IDs must be unique", attr->source);
AddNote("a pipeline constant with an ID of " + std::to_string(id) +
" was previously declared "
@@ -502,18 +599,21 @@
}
if (!decl->is_const) {
- if (!ValidateAtomicVariable(var)) {
+ if (!AtomicVariable(var, atomic_composite_info)) {
return false;
}
}
- return ValidateVariable(var);
+ return Variable(var);
}
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// Atomic types may only be instantiated by variables in the workgroup storage
// class or by storage buffer variables with a read_write access mode.
-bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const {
+bool Validator::AtomicVariable(
+ const sem::Variable* var,
+ std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
+ const {
auto sc = var->StorageClass();
auto* decl = var->Declaration();
auto access = var->Access();
@@ -529,8 +629,8 @@
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
- auto found = atomic_composite_info_.find(type);
- if (found != atomic_composite_info_.end()) {
+ auto found = atomic_composite_info.find(type);
+ if (found != atomic_composite_info.end()) {
if (sc != ast::StorageClass::kStorage &&
sc != ast::StorageClass::kWorkgroup) {
AddError(
@@ -557,12 +657,12 @@
return true;
}
-bool Resolver::ValidateVariable(const sem::Variable* var) const {
+bool Validator::Variable(const sem::Variable* var) const {
auto* decl = var->Declaration();
auto* storage_ty = var->Type()->UnwrapRef();
if (var->Is<sem::GlobalVariable>()) {
- auto name = builder_->Symbols().NameFor(decl->symbol);
+ auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
auto* kind = var->Declaration()->is_const ? "let" : "var";
AddError(
@@ -634,9 +734,9 @@
return true;
}
-bool Resolver::ValidateFunctionParameter(const ast::Function* func,
- const sem::Variable* var) const {
- if (!ValidateVariable(var)) {
+bool Validator::FunctionParameter(const ast::Function* func,
+ const sem::Variable* var) const {
+ if (!Variable(var)) {
return false;
}
@@ -697,10 +797,10 @@
return true;
}
-bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
- const sem::Type* storage_ty,
- ast::PipelineStage stage,
- const bool is_input) const {
+bool Validator::BuiltinAttribute(const ast::BuiltinAttribute* attr,
+ const sem::Type* storage_ty,
+ ast::PipelineStage stage,
+ const bool is_input) const {
auto* type = storage_ty->UnwrapRef();
std::stringstream stage_name;
stage_name << stage;
@@ -816,9 +916,8 @@
return true;
}
-bool Resolver::ValidateInterpolateAttribute(
- const ast::InterpolateAttribute* attr,
- const sem::Type* storage_ty) const {
+bool Validator::InterpolateAttribute(const ast::InterpolateAttribute* attr,
+ const sem::Type* storage_ty) const {
auto* type = storage_ty->UnwrapRef();
if (type->is_integer_scalar_or_vector() &&
@@ -839,11 +938,11 @@
return true;
}
-bool Resolver::ValidateFunction(const sem::Function* func,
- ast::PipelineStage stage) const {
+bool Validator::Function(const sem::Function* func,
+ ast::PipelineStage stage) const {
auto* decl = func->Declaration();
- auto name = builder_->Symbols().NameFor(decl->symbol);
+ auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError(
"'" + name + "' is a builtin and cannot be redeclared as a function",
@@ -873,7 +972,7 @@
}
for (size_t i = 0; i < decl->params.size(); i++) {
- if (!ValidateFunctionParameter(decl, func->Parameters()[i])) {
+ if (!FunctionParameter(decl, func->Parameters()[i])) {
return false;
}
}
@@ -898,8 +997,7 @@
decl->attributes,
ast::DisabledValidation::kFunctionHasNoBody)) {
TINT_ICE(Resolver, diagnostics_)
- << "Function " << builder_->Symbols().NameFor(decl->symbol)
- << " has no body";
+ << "Function " << symbols_.NameFor(decl->symbol) << " has no body";
}
for (auto* attr : decl->return_type_attributes) {
@@ -925,7 +1023,7 @@
}
if (decl->IsEntryPoint()) {
- if (!ValidateEntryPoint(func, stage)) {
+ if (!EntryPoint(func, stage)) {
return false;
}
}
@@ -945,8 +1043,8 @@
return true;
}
-bool Resolver::ValidateEntryPoint(const sem::Function* func,
- ast::PipelineStage stage) const {
+bool Validator::EntryPoint(const sem::Function* func,
+ ast::PipelineStage stage) const {
auto* decl = func->Declaration();
// Use a lambda to validate the entry point attributes for a type.
@@ -994,7 +1092,7 @@
return false;
}
- if (!ValidateBuiltinAttribute(
+ if (!BuiltinAttribute(
builtin, ty, stage,
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false;
@@ -1011,14 +1109,14 @@
bool is_input = param_or_ret == ParamOrRetType::kParameter;
- if (!ValidateLocationAttribute(location, ty, locations, stage, source,
- is_input)) {
+ if (!LocationAttribute(location, ty, locations, stage, source,
+ is_input)) {
return false;
}
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_attribute = true;
- } else if (!ValidateInterpolateAttribute(interpolate, ty)) {
+ } else if (!InterpolateAttribute(interpolate, ty)) {
return false;
}
interpolate_attribute = interpolate;
@@ -1122,7 +1220,7 @@
member->Declaration()->source, param_or_ret,
/*is_struct_member*/ true)) {
AddNote("while analysing entry point '" +
- builder_->Symbols().NameFor(decl->symbol) + "'",
+ symbols_.NameFor(decl->symbol) + "'",
decl->source);
return false;
}
@@ -1206,7 +1304,7 @@
// variables in the resource interface of a given shader must not have
// the same group and binding values, when considered as a pair of
// values.
- auto func_name = builder_->Symbols().NameFor(decl->symbol);
+ auto func_name = symbols_.NameFor(decl->symbol);
AddError("entry point '" + func_name +
"' references multiple variables that use the "
"same resource binding @group(" +
@@ -1222,7 +1320,7 @@
return true;
}
-bool Resolver::ValidateStatements(const ast::StatementList& stmts) const {
+bool Validator::Statements(const ast::StatementList& stmts) const {
for (auto* stmt : stmts) {
if (!sem_.Get(stmt)->IsReachable()) {
/// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
@@ -1234,8 +1332,8 @@
return true;
}
-bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
- const sem::Type* to) const {
+bool Validator::Bitcast(const ast::BitcastExpression* cast,
+ const sem::Type* to) const {
auto* from = sem_.TypeOf(cast->expr)->UnwrapRef();
if (!from->is_numeric_scalar_or_vector()) {
AddError("'" + sem_.TypeNameOf(from) + "' cannot be bitcast",
@@ -1265,13 +1363,15 @@
return true;
}
-bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const {
+bool Validator::BreakStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const {
if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
return false;
}
- if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
+ if (auto* continuing =
+ ClosestContinuing(/*stop_at_loop*/ true, current_statement)) {
auto fail = [&](const char* note_msg, const Source& note_src) {
constexpr const char* kErrorMsg =
"break statement in a continuing block must be the single statement "
@@ -1332,8 +1432,10 @@
return true;
}
-bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const {
- if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
+bool Validator::ContinueStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const {
+ if (auto* continuing =
+ ClosestContinuing(/*stop_at_loop*/ true, current_statement)) {
AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source);
if (continuing != stmt->Declaration() &&
@@ -1352,8 +1454,10 @@
return true;
}
-bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const {
- if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+bool Validator::DiscardStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const {
+ if (auto* continuing =
+ ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source);
if (continuing != stmt->Declaration() &&
@@ -1365,7 +1469,7 @@
return true;
}
-bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const {
+bool Validator::FallthroughStatement(const sem::Statement* stmt) const {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) {
@@ -1388,7 +1492,7 @@
return false;
}
-bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const {
+bool Validator::ElseStatement(const sem::ElseStatement* stmt) const {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
@@ -1401,7 +1505,7 @@
return true;
}
-bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const {
+bool Validator::LoopStatement(const sem::LoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) {
AddError("loop does not exit", stmt->Declaration()->source.Begin());
return false;
@@ -1409,8 +1513,7 @@
return true;
}
-bool Resolver::ValidateForLoopStatement(
- const sem::ForLoopStatement* stmt) const {
+bool Validator::ForLoopStatement(const sem::ForLoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) {
AddError("for-loop does not exit", stmt->Declaration()->source.Begin());
return false;
@@ -1427,7 +1530,7 @@
return true;
}
-bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const {
+bool Validator::IfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError(
@@ -1438,7 +1541,7 @@
return true;
}
-bool Resolver::ValidateBuiltinCall(const sem::Call* call) const {
+bool Validator::BuiltinCall(const sem::Call* call) const {
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
@@ -1451,7 +1554,7 @@
// If the called function does not return a value, a function call
// statement should be used instead.
auto* ident = call->Declaration()->target.name;
- auto name = builder_->Symbols().NameFor(ident->symbol);
+ auto name = symbols_.NameFor(ident->symbol);
AddError("builtin '" + name + "' does not return a value",
call->Declaration()->source);
return false;
@@ -1461,7 +1564,7 @@
return true;
}
-bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const {
+bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin) {
return false;
@@ -1533,11 +1636,12 @@
check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
}
-bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
+bool Validator::FunctionCall(const sem::Call* call,
+ sem::Statement* current_statement) const {
auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>();
auto sym = decl->target.name->symbol;
- auto name = builder_->Symbols().NameFor(sym);
+ auto name = symbols_.NameFor(sym);
if (target->Declaration()->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction
@@ -1575,7 +1679,7 @@
if (param_type->Is<sem::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
- auto* var = ResolvedSymbol<sem::Variable>(ident_expr);
+ auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr);
if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
@@ -1587,7 +1691,7 @@
if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary =
unary->expr->As<ast::IdentifierExpression>()) {
- auto* var = ResolvedSymbol<sem::Variable>(ident_unary);
+ auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary);
if (!var) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier";
@@ -1634,7 +1738,8 @@
}
if (call->Behaviors().Contains(sem::Behavior::kDiscard)) {
- if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+ if (auto* continuing =
+ ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError(
"cannot call a function that may discard inside a continuing block",
call->Declaration()->source);
@@ -1649,7 +1754,7 @@
return true;
}
-bool Resolver::ValidateStructureConstructorOrCast(
+bool Validator::StructureConstructorOrCast(
const ast::CallExpression* ctor,
const sem::Struct* struct_type) const {
if (!struct_type->IsConstructible()) {
@@ -1684,9 +1789,8 @@
return true;
}
-bool Resolver::ValidateArrayConstructorOrCast(
- const ast::CallExpression* ctor,
- const sem::Array* array_type) const {
+bool Validator::ArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* array_type) const {
auto& values = ctor->args;
auto* elem_ty = array_type->ElemType();
for (auto* value : values) {
@@ -1726,9 +1830,8 @@
return true;
}
-bool Resolver::ValidateVectorConstructorOrCast(
- const ast::CallExpression* ctor,
- const sem::Vector* vec_type) const {
+bool Validator::VectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type) const {
auto& values = ctor->args;
auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0;
@@ -1790,8 +1893,7 @@
return true;
}
-bool Resolver::ValidateVector(const sem::Vector* ty,
- const Source& source) const {
+bool Validator::Vector(const sem::Vector* ty, const Source& source) const {
if (!ty->type()->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
source);
@@ -1800,8 +1902,7 @@
return true;
}
-bool Resolver::ValidateMatrix(const sem::Matrix* ty,
- const Source& source) const {
+bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const {
if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source);
return false;
@@ -1809,16 +1910,15 @@
return true;
}
-bool Resolver::ValidateMatrixConstructorOrCast(
- const ast::CallExpression* ctor,
- const sem::Matrix* matrix_ty) const {
+bool Validator::MatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_ty) const {
auto& values = ctor->args;
// Zero Value expression
if (values.empty()) {
return true;
}
- if (!ValidateMatrix(matrix_ty, ctor->source)) {
+ if (!Matrix(matrix_ty, ctor->source)) {
return false;
}
@@ -1844,7 +1944,7 @@
if (i > 0) {
ss << ", ";
}
- ss << arg_tys[i]->FriendlyName(builder_->Symbols());
+ ss << arg_tys[i]->FriendlyName(symbols_);
}
ss << ")" << std::endl << std::endl;
ss << "3 candidates available:" << std::endl;
@@ -1885,8 +1985,8 @@
return true;
}
-bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
- const sem::Type* ty) const {
+bool Validator::ScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* ty) const {
if (ctor->args.size() == 0) {
return true;
}
@@ -1921,7 +2021,8 @@
return true;
}
-bool Resolver::ValidatePipelineStages() const {
+bool Validator::PipelineStages(
+ const std::vector<sem::Function*>& entry_points) const {
auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage();
@@ -1940,17 +2041,14 @@
}
AddNote("variable is declared here", var->Declaration()->source);
if (func != entry_point) {
- TraverseCallChain(diagnostics_, entry_point, func,
- [&](const sem::Function* f) {
- AddNote("called by function '" +
- builder_->Symbols().NameFor(
- f->Declaration()->symbol) +
- "'",
- f->Declaration()->source);
- });
+ TraverseCallChain(
+ diagnostics_, entry_point, func, [&](const sem::Function* f) {
+ AddNote("called by function '" +
+ symbols_.NameFor(f->Declaration()->symbol) + "'",
+ f->Declaration()->source);
+ });
AddNote("called by entry point '" +
- builder_->Symbols().NameFor(
- entry_point->Declaration()->symbol) +
+ symbols_.NameFor(entry_point->Declaration()->symbol) +
"'",
entry_point->Declaration()->source);
}
@@ -1961,7 +2059,7 @@
return true;
};
- for (auto* entry_point : entry_points_) {
+ for (auto* entry_point : entry_points) {
if (!check_workgroup_storage(entry_point, entry_point)) {
return false;
}
@@ -1985,15 +2083,12 @@
if (func != entry_point) {
TraverseCallChain(
diagnostics_, entry_point, func, [&](const sem::Function* f) {
- AddNote(
- "called by function '" +
- builder_->Symbols().NameFor(f->Declaration()->symbol) +
- "'",
- f->Declaration()->source);
+ AddNote("called by function '" +
+ symbols_.NameFor(f->Declaration()->symbol) + "'",
+ f->Declaration()->source);
});
AddNote("called by entry point '" +
- builder_->Symbols().NameFor(
- entry_point->Declaration()->symbol) +
+ symbols_.NameFor(entry_point->Declaration()->symbol) +
"'",
entry_point->Declaration()->source);
}
@@ -2003,7 +2098,7 @@
return true;
};
- for (auto* entry_point : entry_points_) {
+ for (auto* entry_point : entry_points) {
if (!check_builtin_calls(entry_point, entry_point)) {
return false;
}
@@ -2016,8 +2111,7 @@
return true;
}
-bool Resolver::ValidateArray(const sem::Array* arr,
- const Source& source) const {
+bool Validator::Array(const sem::Array* arr, const Source& source) const {
auto* el_ty = arr->ElemType();
if (!IsFixedFootprint(el_ty)) {
@@ -2028,10 +2122,10 @@
return true;
}
-bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
- uint32_t el_size,
- uint32_t el_align,
- const Source& source) const {
+bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr,
+ uint32_t el_size,
+ uint32_t el_align,
+ const Source& source) const {
auto stride = attr->stride;
bool is_valid_stride =
(stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
@@ -2050,8 +2144,8 @@
return true;
}
-bool Resolver::ValidateAlias(const ast::Alias* alias) const {
- auto name = builder_->Symbols().NameFor(alias->name);
+bool Validator::Alias(const ast::Alias* alias) const {
+ auto name = symbols_.NameFor(alias->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
alias->source);
@@ -2061,9 +2155,9 @@
return true;
}
-bool Resolver::ValidateStructure(const sem::Struct* str,
- ast::PipelineStage stage) const {
- auto name = builder_->Symbols().NameFor(str->Declaration()->name);
+bool Validator::Structure(const sem::Struct* str,
+ ast::PipelineStage stage) const {
+ auto name = symbols_.NameFor(str->Declaration()->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
str->Declaration()->source);
@@ -2122,13 +2216,13 @@
invariant_attribute = invariant;
} else if (auto* location = attr->As<ast::LocationAttribute>()) {
has_location = true;
- if (!ValidateLocationAttribute(location, member->Type(), locations,
- stage, member->Declaration()->source)) {
+ if (!LocationAttribute(location, member->Type(), locations, stage,
+ member->Declaration()->source)) {
return false;
}
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
- if (!ValidateBuiltinAttribute(builtin, member->Type(), stage,
- /* is_input */ false)) {
+ if (!BuiltinAttribute(builtin, member->Type(), stage,
+ /* is_input */ false)) {
return false;
}
if (builtin->builtin == ast::Builtin::kPosition) {
@@ -2136,7 +2230,7 @@
}
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
interpolate_attribute = interpolate;
- if (!ValidateInterpolateAttribute(interpolate, member->Type())) {
+ if (!InterpolateAttribute(interpolate, member->Type())) {
return false;
}
}
@@ -2165,13 +2259,12 @@
return true;
}
-bool Resolver::ValidateLocationAttribute(
- const ast::LocationAttribute* location,
- const sem::Type* type,
- std::unordered_set<uint32_t>& locations,
- ast::PipelineStage stage,
- const Source& source,
- const bool is_input) const {
+bool Validator::LocationAttribute(const ast::LocationAttribute* location,
+ const sem::Type* type,
+ std::unordered_set<uint32_t>& locations,
+ ast::PipelineStage stage,
+ const Source& source,
+ const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) {
AddError("attribute is not valid for compute shader " + inputs_or_output,
@@ -2201,9 +2294,10 @@
return true;
}
-bool Resolver::ValidateReturn(const ast::ReturnStatement* ret,
- const sem::Type* func_type,
- const sem::Type* ret_type) const {
+bool Validator::Return(const ast::ReturnStatement* ret,
+ const sem::Type* func_type,
+ const sem::Type* ret_type,
+ sem::Statement* current_statement) const {
if (func_type->UnwrapRef() != ret_type) {
AddError(
"return statement type must match its function "
@@ -2215,7 +2309,8 @@
}
auto* sem = sem_.Get(ret);
- if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
+ if (auto* continuing =
+ ClosestContinuing(/*stop_at_loop*/ false, current_statement)) {
AddError("continuing blocks must not contain a return statement",
ret->source);
if (continuing != sem->Declaration() &&
@@ -2228,7 +2323,7 @@
return true;
}
-bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
+bool Validator::SwitchStatement(const ast::SwitchStatement* s) {
auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef();
if (!cond_ty->is_integer_scalar()) {
AddError(
@@ -2284,8 +2379,8 @@
return true;
}
-bool Resolver::ValidateAssignment(const ast::Statement* a,
- const sem::Type* rhs_ty) const {
+bool Validator::Assignment(const ast::Statement* a,
+ const sem::Type* rhs_ty) const {
const ast::Expression* lhs;
const ast::Expression* rhs;
if (auto* assign = a->As<ast::AssignmentStatement>()) {
@@ -2317,19 +2412,17 @@
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
auto const* lhs_ty = sem_.TypeOf(lhs);
- if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) {
+ if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) {
AddError("cannot assign to function parameter", lhs->source);
- AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
- "' is declared here:",
+ AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
decl->source);
return false;
}
if (decl->is_const) {
AddError("cannot assign to const", lhs->source);
- AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
- "' is declared here:",
+ AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
decl->source);
return false;
}
@@ -2366,25 +2459,23 @@
return true;
}
-bool Resolver::ValidateIncrementDecrementStatement(
+bool Validator::IncrementDecrementStatement(
const ast::IncrementDecrementStatement* inc) const {
const ast::Expression* lhs = inc->lhs;
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
- if (auto* var = ResolvedSymbol<sem::Variable>(lhs)) {
+ if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) {
AddError("cannot modify function parameter", lhs->source);
- AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
- "' is declared here:",
+ AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
decl->source);
return false;
}
if (decl->is_const) {
AddError("cannot modify constant value", lhs->source);
- AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
- "' is declared here:",
+ AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:",
decl->source);
return false;
}
@@ -2415,7 +2506,7 @@
return true;
}
-bool Resolver::ValidateNoDuplicateAttributes(
+bool Validator::NoDuplicateAttributes(
const ast::AttributeList& attributes) const {
std::unordered_map<const TypeInfo*, Source> seen;
for (auto* d : attributes) {
@@ -2429,8 +2520,8 @@
return true;
}
-bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes,
- ast::DisabledValidation validation) const {
+bool Validator::IsValidationDisabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const {
for (auto* attribute : attributes) {
if (auto* dv = attribute->As<ast::DisableValidationAttribute>()) {
if (dv->validation == validation) {
@@ -2441,15 +2532,15 @@
return false;
}
-bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes,
- ast::DisabledValidation validation) const {
+bool Validator::IsValidationEnabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const {
return !IsValidationDisabled(attributes, validation);
}
-std::string Resolver::VectorPretty(uint32_t size,
- const sem::Type* element_type) const {
+std::string Validator::VectorPretty(uint32_t size,
+ const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size);
- return vec_type.FriendlyName(builder_->Symbols());
+ return vec_type.FriendlyName(symbols_);
}
} // namespace tint::resolver
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
new file mode 100644
index 0000000..146f7ba
--- /dev/null
+++ b/src/tint/resolver/validator.h
@@ -0,0 +1,457 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_RESOLVER_VALIDATOR_H_
+#define SRC_TINT_RESOLVER_VALIDATOR_H_
+
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "src/tint/ast/pipeline_stage.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/resolver/sem_helper.h"
+#include "src/tint/source.h"
+
+// Forward declarations
+namespace tint::ast {
+class IndexAccessorExpression;
+class BinaryExpression;
+class BitcastExpression;
+class CallExpression;
+class CallStatement;
+class CaseStatement;
+class ForLoopStatement;
+class Function;
+class IdentifierExpression;
+class LoopStatement;
+class MemberAccessorExpression;
+class ReturnStatement;
+class SwitchStatement;
+class UnaryOpExpression;
+class Variable;
+} // namespace tint::ast
+namespace tint::sem {
+class Array;
+class Atomic;
+class BlockStatement;
+class Builtin;
+class CaseStatement;
+class ElseStatement;
+class ForLoopStatement;
+class IfStatement;
+class LoopStatement;
+class Statement;
+class SwitchStatement;
+class TypeConstructor;
+} // namespace tint::sem
+
+namespace tint::resolver {
+
+/// Validation logic for various ast nodes. The validations in general should
+/// be shallow and depend on the resolver to call on children. The validations
+/// also assume that sem changes have already been made. The validation checks
+/// should not alter the AST or SEM trees.
+class Validator {
+ public:
+ /// The valid type storage layouts typedef
+ using ValidTypeStorageLayouts =
+ std::set<std::pair<const sem::Type*, ast::StorageClass>>;
+
+ /// Constructor
+ /// @param builder the program builder
+ /// @param helper the SEM helper to validate with
+ Validator(ProgramBuilder* builder, SemHelper& helper);
+ ~Validator();
+
+ /// Adds the given error message to the diagnostics
+ /// @param msg the error message
+ /// @param source the error source
+ void AddError(const std::string& msg, const Source& source) const;
+
+ /// Adds the given warning message to the diagnostics
+ /// @param msg the warning message
+ /// @param source the warning source
+ void AddWarning(const std::string& msg, const Source& source) const;
+
+ /// Adds the given note message to the diagnostics
+ /// @param msg the note message
+ /// @param source the note source
+ void AddNote(const std::string& msg, const Source& source) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is a plain type
+ bool IsPlain(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is a fixed-footprint type
+ bool IsFixedFootprint(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is storable
+ bool IsStorable(const sem::Type* type) const;
+
+ /// @param type the given type
+ /// @returns true if the given type is host-shareable
+ bool IsHostShareable(const sem::Type* type) const;
+
+ /// Validates pipeline stages
+ /// @param entry_points the entry points to the module
+ /// @returns true on success, false otherwise.
+ bool PipelineStages(const std::vector<sem::Function*>& entry_points) const;
+
+ /// Validates aliases
+ /// @param alias the alias to validate
+ /// @returns true on success, false otherwise.
+ bool Alias(const ast::Alias* alias) const;
+
+ /// Validates the array
+ /// @param arr the array to validate
+ /// @param source the source of the array
+ /// @returns true on success, false otherwise.
+ bool Array(const sem::Array* arr, const Source& source) const;
+
+ /// Validates an array stride attribute
+ /// @param attr the stride attribute to validate
+ /// @param el_size the element size
+ /// @param el_align the element alignment
+ /// @param source the source of the attribute
+ /// @returns true on success, false otherwise
+ bool ArrayStrideAttribute(const ast::StrideAttribute* attr,
+ uint32_t el_size,
+ uint32_t el_align,
+ const Source& source) const;
+
+ /// Validates an atomic
+ /// @param a the atomic ast node to validate
+ /// @param s the atomic sem node
+ /// @returns true on success, false otherwise.
+ bool Atomic(const ast::Atomic* a, const sem::Atomic* s) const;
+
+ /// Validates an atoic variable
+ /// @param var the variable to validate
+ /// @param atomic_composite_info store atomic information
+ /// @returns true on success, false otherwise.
+ bool AtomicVariable(const sem::Variable* var,
+ std::unordered_map<const sem::Type*, const Source&>
+ atomic_composite_info) const;
+
+ /// Validates an assignment
+ /// @param a the assignment statement
+ /// @param rhs_ty the type of the right hand side
+ /// @returns true on success, false otherwise.
+ bool Assignment(const ast::Statement* a, const sem::Type* rhs_ty) const;
+
+ /// Validates a bitcase
+ /// @param cast the bitcast expression
+ /// @param to the destination type
+ /// @returns true on success, false otherwise
+ bool Bitcast(const ast::BitcastExpression* cast, const sem::Type* to) const;
+
+ /// Validates a break statement
+ /// @param stmt the break statement to validate
+ /// @param current_statement the current statement being resolved
+ /// @returns true on success, false otherwise.
+ bool BreakStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const;
+
+ /// Validates a builtin attribute
+ /// @param attr the attribute to validate
+ /// @param storage_type the attribute storage type
+ /// @param stage the current pipeline stage
+ /// @param is_input true if this is an input attribute
+ /// @returns true on success, false otherwise.
+ bool BuiltinAttribute(const ast::BuiltinAttribute* attr,
+ const sem::Type* storage_type,
+ ast::PipelineStage stage,
+ const bool is_input) const;
+
+ /// Validates a continue statement
+ /// @param stmt the continue statement to validate
+ /// @param current_statement the current statement being resolved
+ /// @returns true on success, false otherwise
+ bool ContinueStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const;
+
+ /// Validates a discard statement
+ /// @param stmt the statement to validate
+ /// @param current_statement the current statement being resolved
+ /// @returns true on success, false otherwise
+ bool DiscardStatement(const sem::Statement* stmt,
+ sem::Statement* current_statement) const;
+
+ /// Validates an else statement
+ /// @param stmt the else statement to validate
+ /// @returns true on success, false otherwise
+ bool ElseStatement(const sem::ElseStatement* stmt) const;
+
+ /// Validates an entry point
+ /// @param func the entry point function to validate
+ /// @param stage the pipeline stage for the entry point
+ /// @returns true on success, false otherwise
+ bool EntryPoint(const sem::Function* func, ast::PipelineStage stage) const;
+
+ /// Validates a for loop
+ /// @param stmt the for loop statement to validate
+ /// @returns true on success, false otherwise
+ bool ForLoopStatement(const sem::ForLoopStatement* stmt) const;
+
+ /// Validates a fallthrough statement
+ /// @param stmt the fallthrough to validate
+ /// @returns true on success, false otherwise
+ bool FallthroughStatement(const sem::Statement* stmt) const;
+
+ /// Validates a function
+ /// @param func the function to validate
+ /// @param stage the current pipeline stage
+ /// @returns true on success, false otherwise.
+ bool Function(const sem::Function* func, ast::PipelineStage stage) const;
+
+ /// Validates a function call
+ /// @param call the function call to validate
+ /// @param current_statement the current statement being resolved
+ /// @returns true on success, false otherwise
+ bool FunctionCall(const sem::Call* call,
+ sem::Statement* current_statement) const;
+
+ /// Validates a global variable
+ /// @param var the global variable to validate
+ /// @param constant_ids the set of constant ids in the module
+ /// @param atomic_composite_info atomic composite info in the module
+ /// @returns true on success, false otherwise
+ bool GlobalVariable(
+ const sem::Variable* var,
+ std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
+ std::unordered_map<const sem::Type*, const Source&> atomic_composite_info)
+ const;
+
+ /// Validates an if statement
+ /// @param stmt the statement to validate
+ /// @returns true on success, false otherwise
+ bool IfStatement(const sem::IfStatement* stmt) const;
+
+ /// Validates an increment or decrement statement
+ /// @param stmt the statement to validate
+ /// @returns true on success, false otherwise
+ bool IncrementDecrementStatement(
+ const ast::IncrementDecrementStatement* stmt) const;
+
+ /// Validates an interpolate attribute
+ /// @param attr the interpolation attribute to validate
+ /// @param storage_type the storage type of the attached variable
+ /// @returns true on succes, false otherwise
+ bool InterpolateAttribute(const ast::InterpolateAttribute* attr,
+ const sem::Type* storage_type) const;
+
+ /// Validates a builtin call
+ /// @param call the builtin call to validate
+ /// @returns true on success, false otherwise.
+ bool BuiltinCall(const sem::Call* call) const;
+
+ /// Validates a location attribute
+ /// @param location the location attribute to validate
+ /// @param type the variable type
+ /// @param locations the set of locations in the module
+ /// @param stage the current pipeline stage
+ /// @param source the source of the attribute
+ /// @param is_input true if this is an input variable
+ /// @returns true on success, false otherwise.
+ bool LocationAttribute(const ast::LocationAttribute* location,
+ const sem::Type* type,
+ std::unordered_set<uint32_t>& locations,
+ ast::PipelineStage stage,
+ const Source& source,
+ const bool is_input = false) const;
+
+ /// Validates a loop statement
+ /// @param stmt the loop statement
+ /// @returns true on success, false otherwise.
+ bool LoopStatement(const sem::LoopStatement* stmt) const;
+
+ /// Validates a matrix
+ /// @param ty the matrix to validate
+ /// @param source the source of the matrix
+ /// @returns true on success, false otherwise
+ bool Matrix(const sem::Matrix* ty, const Source& source) const;
+
+ /// Validates a function parameter
+ /// @param func the function the variable is for
+ /// @param var the variable to validate
+ /// @returns true on success, false otherwise
+ bool FunctionParameter(const ast::Function* func,
+ const sem::Variable* var) const;
+
+ /// Validates a return
+ /// @param ret the return statement to validate
+ /// @param func_type the return type of the curreunt function
+ /// @param ret_type the return type
+ /// @param current_statement the current statement being resolved
+ /// @returns true on success, false otherwise
+ bool Return(const ast::ReturnStatement* ret,
+ const sem::Type* func_type,
+ const sem::Type* ret_type,
+ sem::Statement* current_statement) const;
+
+ /// Validates a list of statements
+ /// @param stmts the statements to validate
+ /// @returns true on success, false otherwise
+ bool Statements(const ast::StatementList& stmts) const;
+
+ /// Validates a storage texture
+ /// @param t the texture to validate
+ /// @returns true on success, false otherwise
+ bool StorageTexture(const ast::StorageTexture* t) const;
+
+ /// Validates a structure
+ /// @param str the structure to validate
+ /// @param stage the current pipeline stage
+ /// @returns true on success, false otherwise.
+ bool Structure(const sem::Struct* str, ast::PipelineStage stage) const;
+
+ /// Validates a structure constructor or cast
+ /// @param ctor the call expression to validate
+ /// @param struct_type the type of the structure
+ /// @returns true on success, false otherwise
+ bool StructureConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Struct* struct_type) const;
+
+ /// Validates a switch statement
+ /// @param s the switch to validate
+ /// @returns true on success, false otherwise
+ bool SwitchStatement(const ast::SwitchStatement* s);
+
+ /// Validates a variable
+ /// @param var the variable to validate
+ /// @returns true on success, false otherwise.
+ bool Variable(const sem::Variable* var) const;
+
+ /// Validates a variable constructor or cast
+ /// @param var the variable to validate
+ /// @param storage_class the storage class of the variable
+ /// @param storage_type the type of the storage
+ /// @param rhs_type the right hand side of the expression
+ /// @returns true on succes, false otherwise
+ bool VariableConstructorOrCast(const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_type,
+ const sem::Type* rhs_type) const;
+
+ /// Validates a vector
+ /// @param ty the vector to validate
+ /// @param source the source of the vector
+ /// @returns true on success, false otherwise
+ bool Vector(const sem::Vector* ty, const Source& source) const;
+
+ /// Validates a vector constructor or cast
+ /// @param ctor the call expression to validate
+ /// @param vec_type the vector type
+ /// @returns true on success, false otherwise
+ bool VectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type) const;
+
+ /// Validates a matrix constructor or cast
+ /// @param ctor the call expression to validate
+ /// @param matrix_type the type of the matrix
+ /// @returns true on success, false otherwise
+ bool MatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_type) const;
+
+ /// Validates a scalar constructor or cast
+ /// @param ctor the call expression to validate
+ /// @param type the type of the scalar
+ /// @returns true on success, false otherwise.
+ bool ScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* type) const;
+
+ /// Validates an array constructor or cast
+ /// @param ctor the call expresion to validate
+ /// @param arr_type the type of the array
+ /// @returns true on success, false otherwise
+ bool ArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* arr_type) const;
+
+ /// Validates a texture builtin function
+ /// @param call the builtin call to validate
+ /// @returns true on success, false otherwise
+ bool TextureBuiltinFunction(const sem::Call* call) const;
+
+ /// Validates there are no duplicate attributes
+ /// @param attributes the list of attributes to validate
+ /// @returns true on success, false otherwise.
+ bool NoDuplicateAttributes(const ast::AttributeList& attributes) const;
+
+ /// Validates a storage class layout
+ /// @param type the type to validate
+ /// @param sc the storage class
+ /// @param source the source of the type
+ /// @param layouts previously validated storage layouts
+ /// @returns true on success, false otherwise
+ bool StorageClassLayout(const sem::Type* type,
+ ast::StorageClass sc,
+ Source source,
+ ValidTypeStorageLayouts& layouts) const;
+
+ /// Validates a storage class layout
+ /// @param var the variable to validate
+ /// @param layouts previously validated storage layouts
+ /// @returns true on success, false otherwise.
+ bool StorageClassLayout(const sem::Variable* var,
+ ValidTypeStorageLayouts& layouts) const;
+
+ /// @returns true if the attribute list contains a
+ /// ast::DisableValidationAttribute with the validation mode equal to
+ /// `validation`
+ /// @param attributes the attribute list to check
+ /// @param validation the validation mode to check
+ bool IsValidationDisabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const;
+
+ /// @returns true if the attribute list does not contains a
+ /// ast::DisableValidationAttribute with the validation mode equal to
+ /// `validation`
+ /// @param attributes the attribute list to check
+ /// @param validation the validation mode to check
+ bool IsValidationEnabled(const ast::AttributeList& attributes,
+ ast::DisabledValidation validation) const;
+
+ private:
+ /// Searches the current statement and up through parents of the current
+ /// statement looking for a loop or for-loop continuing statement.
+ /// @returns the closest continuing statement to the current statement that
+ /// (transitively) owns the current statement.
+ /// @param stop_at_loop if true then the function will return nullptr if a
+ /// loop or for-loop was found before the continuing.
+ /// @param current_statement the current statement being resolved
+ const ast::Statement* ClosestContinuing(
+ bool stop_at_loop,
+ sem::Statement* current_statement) const;
+
+ /// Returns a human-readable string representation of the vector type name
+ /// with the given parameters.
+ /// @param size the vector dimension
+ /// @param element_type scalar vector sub-element type
+ /// @return pretty string representation
+ std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
+
+ SymbolTable& symbols_;
+ diag::List& diagnostics_;
+ SemHelper& sem_;
+};
+
+} // namespace tint::resolver
+
+#endif // SRC_TINT_RESOLVER_VALIDATOR_H_
diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc
new file mode 100644
index 0000000..f180936
--- /dev/null
+++ b/src/tint/resolver/validator_is_storeable_test.cc
@@ -0,0 +1,86 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/validator.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/validator_test_helper.h"
+#include "src/tint/sem/atomic_type.h"
+
+namespace tint::resolver {
+namespace {
+
+using ValidatorIsStorableTest = ValidatorTest;
+
+TEST_F(ValidatorIsStorableTest, Void) {
+ EXPECT_FALSE(v()->IsStorable(create<sem::Void>()));
+}
+
+TEST_F(ValidatorIsStorableTest, Scalar) {
+ EXPECT_TRUE(v()->IsStorable(create<sem::Bool>()));
+ EXPECT_TRUE(v()->IsStorable(create<sem::I32>()));
+ EXPECT_TRUE(v()->IsStorable(create<sem::U32>()));
+ EXPECT_TRUE(v()->IsStorable(create<sem::F32>()));
+}
+
+TEST_F(ValidatorIsStorableTest, Vector) {
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::I32>(), 4u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::U32>(), 4u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Vector>(create<sem::F32>(), 4u)));
+}
+
+TEST_F(ValidatorIsStorableTest, Matrix) {
+ auto* vec2 = create<sem::Vector>(create<sem::F32>(), 2u);
+ auto* vec3 = create<sem::Vector>(create<sem::F32>(), 3u);
+ auto* vec4 = create<sem::Vector>(create<sem::F32>(), 4u);
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec2, 4u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec3, 4u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 2u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 3u)));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Matrix>(vec4, 4u)));
+}
+
+TEST_F(ValidatorIsStorableTest, Pointer) {
+ auto* ptr = create<sem::Pointer>(
+ create<sem::I32>(), ast::StorageClass::kPrivate, ast::Access::kReadWrite);
+ EXPECT_FALSE(v()->IsStorable(ptr));
+}
+
+TEST_F(ValidatorIsStorableTest, Atomic) {
+ EXPECT_TRUE(v()->IsStorable(create<sem::Atomic>(create<sem::I32>())));
+ EXPECT_TRUE(v()->IsStorable(create<sem::Atomic>(create<sem::U32>())));
+}
+
+TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
+ EXPECT_TRUE(v()->IsStorable(arr));
+}
+
+TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
+ auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
+ EXPECT_TRUE(v()->IsStorable(arr));
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/validator_test_helper.cc b/src/tint/resolver/validator_test_helper.cc
new file mode 100644
index 0000000..123d5c7
--- /dev/null
+++ b/src/tint/resolver/validator_test_helper.cc
@@ -0,0 +1,27 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/validator_test_helper.h"
+
+#include <memory>
+
+namespace tint::resolver {
+
+TestHelper::TestHelper()
+ : validator_(
+ std::make_unique<Validator>(this->Symbols(), this->Diagnostics())) {}
+
+TestHelper::~TestHelper() = default;
+
+} // namespace tint::resolver
diff --git a/src/tint/resolver/validator_test_helper.h b/src/tint/resolver/validator_test_helper.h
new file mode 100644
index 0000000..3dd3ea2
--- /dev/null
+++ b/src/tint/resolver/validator_test_helper.h
@@ -0,0 +1,46 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_
+#define SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/resolver/validator.h"
+
+namespace tint::resolver {
+
+/// Helper class for testing
+class TestHelper : public ProgramBuilder {
+ public:
+ /// Constructor
+ TestHelper();
+
+ /// Destructor
+ ~TestHelper() override;
+
+ /// @return a pointer to the Validator
+ Validator* v() const { return validator_.get(); }
+
+ private:
+ std::unique_ptr<Validator> validator_;
+};
+
+class ValidatorTest : public TestHelper, public testing::Test {};
+
+} // namespace tint::resolver
+
+#endif // SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_