[tint] Add setters to sem variable types
Allows these to be constructed before resolving the initializer.
Change-Id: Id3976ea1ad05976358b8bfbcce462cc20e2ab93f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/155142
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/wgsl/helpers/append_vector.cc b/src/tint/lang/wgsl/helpers/append_vector.cc
index 0f0da17..8239d58 100644
--- a/src/tint/lang/wgsl/helpers/append_vector.cc
+++ b/src/tint/lang/wgsl/helpers/append_vector.cc
@@ -134,11 +134,9 @@
if (packed_el_sem_ty != scalar_sem->Type()->UnwrapRef()) {
// Cast scalar to the vector element type
auto* scalar_cast_ast = b->Call(packed_el_ast_ty, scalar_ast);
- auto* scalar_cast_target = b->create<sem::ValueConversion>(
- packed_el_sem_ty,
- b->create<sem::Parameter>(nullptr, 0u, scalar_sem->Type()->UnwrapRef(),
- core::AddressSpace::kUndefined, core::Access::kUndefined),
- core::EvaluationStage::kRuntime);
+ auto* param = b->create<sem::Parameter>(nullptr, 0u, scalar_sem->Type()->UnwrapRef());
+ auto* scalar_cast_target = b->create<sem::ValueConversion>(packed_el_sem_ty, param,
+ core::EvaluationStage::kRuntime);
auto* scalar_cast_sem = b->create<sem::Call>(
scalar_cast_ast, scalar_cast_target, core::EvaluationStage::kRuntime,
Vector<const sem::ValueExpression*, 1>{scalar_sem}, statement,
@@ -157,9 +155,8 @@
packed_sem_ty,
tint::Transform(packed,
[&](const tint::sem::ValueExpression* arg, size_t i) {
- return b->create<sem::Parameter>(
- nullptr, static_cast<uint32_t>(i), arg->Type()->UnwrapRef(),
- core::AddressSpace::kUndefined, core::Access::kUndefined);
+ return b->create<sem::Parameter>(nullptr, static_cast<uint32_t>(i),
+ arg->Type()->UnwrapRef());
}),
core::EvaluationStage::kRuntime);
auto* ctor_sem = b->create<sem::Call>(ctor_ast, ctor_target, core::EvaluationStage::kRuntime,
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index a1624f2..c2c1a95 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -230,7 +230,7 @@
return Switch(
v, //
[&](const ast::Var* var) { return Var(var, is_global); },
- [&](const ast::Let* let) { return Let(let, is_global); },
+ [&](const ast::Let* let) { return Let(let); },
[&](const ast::Override* override) { return Override(override); },
[&](const ast::Const* const_) { return Const(const_, is_global); },
[&](Default) {
@@ -242,15 +242,18 @@
});
}
-sem::Variable* Resolver::Let(const ast::Let* v, bool is_global) {
- const core::type::Type* ty = nullptr;
+sem::Variable* Resolver::Let(const ast::Let* v) {
+ auto* sem = b.create<sem::LocalVariable>(v, current_statement_);
+ sem->SetStage(core::EvaluationStage::kRuntime);
+ b.Sem().Add(v, sem);
// If the variable has a declared type, resolve it.
if (v->type) {
- ty = Type(v->type);
- if (!ty) {
+ auto* ty = Type(v->type);
+ if (TINT_UNLIKELY(!ty)) {
return nullptr;
}
+ sem->SetType(ty);
}
for (auto* attribute : v->attributes) {
@@ -267,53 +270,42 @@
}
}
- if (!v->initializer) {
+ if (TINT_UNLIKELY(!v->initializer)) {
AddError("'let' declaration must have an initializer", v->source);
return nullptr;
}
- auto* rhs = Load(Materialize(ValueExpression(v->initializer), ty));
- if (!rhs) {
+ auto* rhs = Load(Materialize(ValueExpression(v->initializer), sem->Type()));
+ if (TINT_UNLIKELY(!rhs)) {
return nullptr;
}
+ sem->SetInitializer(rhs);
// If the variable has no declared type, infer it from the RHS
- if (!ty) {
- ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
+ if (!sem->Type()) {
+ sem->SetType(rhs->Type()->UnwrapRef()); // Implicit load of RHS
}
- if (rhs && !validator_.VariableInitializer(v, ty, rhs)) {
+ if (TINT_UNLIKELY(rhs && !validator_.VariableInitializer(v, sem->Type(), rhs))) {
return nullptr;
}
if (!ApplyAddressSpaceUsageToType(core::AddressSpace::kUndefined,
- const_cast<core::type::Type*>(ty), v->source)) {
+ const_cast<core::type::Type*>(sem->Type()), v->source)) {
AddNote("while instantiating 'let' " + v->name->symbol.Name(), v->source);
return nullptr;
}
- sem::Variable* sem = nullptr;
- if (is_global) {
- sem =
- b.create<sem::GlobalVariable>(v, ty, core::EvaluationStage::kRuntime,
- core::AddressSpace::kUndefined, core::Access::kUndefined,
- /* constant_value */ nullptr, std::nullopt, std::nullopt);
- } else {
- sem = b.create<sem::LocalVariable>(v, ty, core::EvaluationStage::kRuntime,
- core::AddressSpace::kUndefined, core::Access::kUndefined,
- current_statement_,
- /* constant_value */ nullptr);
- }
-
- sem->SetInitializer(rhs);
- b.Sem().Add(v, sem);
return sem;
}
sem::Variable* Resolver::Override(const ast::Override* v) {
- const core::type::Type* ty = nullptr;
+ auto* sem = b.create<sem::GlobalVariable>(v);
+ b.Sem().Add(v, sem);
+ sem->SetStage(core::EvaluationStage::kOverride);
// If the variable has a declared type, resolve it.
+ const core::type::Type* ty = nullptr;
if (v->type) {
ty = Type(v->type);
if (!ty) {
@@ -321,31 +313,31 @@
}
}
- const sem::ValueExpression* rhs = nullptr;
-
// Does the variable have an initializer?
+ const sem::ValueExpression* init = nullptr;
if (v->initializer) {
// Note: RHS must be a const or override expression, which excludes references.
// So there's no need to load or unwrap references here.
-
ExprEvalStageConstraint constraint{core::EvaluationStage::kOverride,
"override initializer"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- rhs = Materialize(ValueExpression(v->initializer), ty);
- if (!rhs) {
+ init = Materialize(ValueExpression(v->initializer), ty);
+ if (TINT_UNLIKELY(!init)) {
return nullptr;
}
+ sem->SetInitializer(init);
- // If the variable has no declared type, infer it from the RHS
+ // If the variable has no declared type, infer it from the initializer
if (!ty) {
- ty = rhs->Type();
+ ty = init->Type();
}
} else if (!ty) {
AddError("override declaration requires a type or initializer", v->source);
return nullptr;
}
+ sem->SetType(ty);
- if (rhs && !validator_.VariableInitializer(v, ty, rhs)) {
+ if (init && !validator_.VariableInitializer(v, ty, init)) {
return nullptr;
}
@@ -355,12 +347,6 @@
return nullptr;
}
- auto* sem =
- b.create<sem::GlobalVariable>(v, ty, core::EvaluationStage::kOverride,
- core::AddressSpace::kUndefined, core::Access::kUndefined,
- /* constant_value */ nullptr, std::nullopt, std::nullopt);
- sem->SetInitializer(rhs);
-
for (auto* attribute : v->attributes) {
Mark(attribute);
bool ok = Switch(
@@ -408,25 +394,17 @@
}
}
- b.Sem().Add(v, sem);
return sem;
}
sem::Variable* Resolver::Const(const ast::Const* c, bool is_global) {
- const core::type::Type* ty = nullptr;
-
- // If the variable has a declared type, resolve it.
- if (c->type) {
- ty = Type(c->type);
- if (!ty) {
- return nullptr;
- }
+ sem::Variable* sem = nullptr;
+ if (is_global) {
+ sem = b.create<sem::GlobalVariable>(c);
+ } else {
+ sem = b.create<sem::LocalVariable>(c, current_statement_);
}
-
- if (!c->initializer) {
- AddError("'const' declaration must have an initializer", c->source);
- return nullptr;
- }
+ b.Sem().Add(c, sem);
for (auto* attribute : c->attributes) {
Mark(attribute);
@@ -440,31 +418,47 @@
}
}
- const sem::ValueExpression* rhs = nullptr;
- {
- ExprEvalStageConstraint constraint{core::EvaluationStage::kConstant, "const initializer"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- rhs = ValueExpression(c->initializer);
- if (!rhs) {
- return nullptr;
- }
+ if (TINT_UNLIKELY(!c->initializer)) {
+ AddError("'const' declaration must have an initializer", c->source);
+ return nullptr;
+ }
+
+ ExprEvalStageConstraint constraint{core::EvaluationStage::kConstant, "const initializer"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+ const auto* init = ValueExpression(c->initializer);
+ if (TINT_UNLIKELY(!init)) {
+ return nullptr;
}
// Note: RHS must be a const expression, which excludes references.
// So there's no need to load or unwrap references here.
+ // If the variable has a declared type, resolve it.
+ const core::type::Type* ty = nullptr;
+ if (c->type) {
+ ty = Type(c->type);
+ if (TINT_UNLIKELY(!ty)) {
+ return nullptr;
+ }
+ }
+
if (ty) {
// If an explicit type was specified, materialize to that type
- rhs = Materialize(rhs, ty);
- if (!rhs) {
+ init = Materialize(init, ty);
+ if (TINT_UNLIKELY(!init)) {
return nullptr;
}
} else {
// If no type was specified, infer it from the RHS
- ty = rhs->Type();
+ ty = init->Type();
}
- if (!validator_.VariableInitializer(c, ty, rhs)) {
+ sem->SetInitializer(init);
+ sem->SetStage(core::EvaluationStage::kConstant);
+ sem->SetConstantValue(init->ConstantValue());
+ sem->SetType(ty);
+
+ if (!validator_.VariableInitializer(c, ty, init)) {
return nullptr;
}
@@ -474,33 +468,30 @@
return nullptr;
}
- const auto value = rhs->ConstantValue();
- auto* sem = is_global
- ? static_cast<sem::Variable*>(b.create<sem::GlobalVariable>(
- c, ty, core::EvaluationStage::kConstant, core::AddressSpace::kUndefined,
- core::Access::kUndefined, value, std::nullopt, std::nullopt))
- : static_cast<sem::Variable*>(b.create<sem::LocalVariable>(
- c, ty, core::EvaluationStage::kConstant, core::AddressSpace::kUndefined,
- core::Access::kUndefined, current_statement_, value));
-
- sem->SetInitializer(rhs);
- b.Sem().Add(c, sem);
return sem;
}
sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) {
- const core::type::Type* storage_ty = nullptr;
+ sem::Variable* sem = nullptr;
+ sem::GlobalVariable* global = nullptr;
+ if (is_global) {
+ global = b.create<sem::GlobalVariable>(var);
+ sem = global;
+ } else {
+ sem = b.create<sem::LocalVariable>(var, current_statement_);
+ }
+ sem->SetStage(core::EvaluationStage::kRuntime);
+ b.Sem().Add(var, sem);
// If the variable has a declared type, resolve it.
+ const core::type::Type* storage_ty = nullptr;
if (auto ty = var->type) {
storage_ty = Type(ty);
- if (!storage_ty) {
+ if (TINT_UNLIKELY(!storage_ty)) {
return nullptr;
}
}
- const sem::ValueExpression* rhs = nullptr;
-
// Does the variable have a initializer?
if (var->initializer) {
ExprEvalStageConstraint constraint{
@@ -509,13 +500,15 @@
};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- rhs = Load(Materialize(ValueExpression(var->initializer), storage_ty));
- if (!rhs) {
+ auto* init = Load(Materialize(ValueExpression(var->initializer), storage_ty));
+ if (TINT_UNLIKELY(!init)) {
return nullptr;
}
+ sem->SetInitializer(init);
+
// If the variable has no declared type, infer it from the RHS
if (!storage_ty) {
- storage_ty = rhs->Type();
+ storage_ty = init->Type();
}
}
@@ -524,62 +517,61 @@
return nullptr;
}
- auto address_space = core::AddressSpace::kUndefined;
if (var->declared_address_space) {
- auto expr = AddressSpaceExpression(var->declared_address_space);
- if (TINT_UNLIKELY(!expr)) {
+ auto space = AddressSpaceExpression(var->declared_address_space);
+ if (TINT_UNLIKELY(!space)) {
return nullptr;
}
- address_space = expr->Value();
+ sem->SetAddressSpace(space->Value());
} else {
// No declared address space. Infer from usage / type.
if (!is_global) {
- address_space = core::AddressSpace::kFunction;
+ sem->SetAddressSpace(core::AddressSpace::kFunction);
} else if (storage_ty->UnwrapRef()->is_handle()) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// If the store type is a texture type or a sampler type, then the
// variable declaration must not have a address space attribute. The
// address space will always be handle.
- address_space = core::AddressSpace::kHandle;
+ sem->SetAddressSpace(core::AddressSpace::kHandle);
}
}
- if (!is_global && address_space != core::AddressSpace::kFunction &&
+ if (!is_global && sem->AddressSpace() != core::AddressSpace::kFunction &&
validator_.IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreAddressSpace)) {
AddError("function-scope 'var' declaration must use 'function' address space", var->source);
return nullptr;
}
- auto access = core::Access::kUndefined;
if (var->declared_access) {
auto expr = AccessExpression(var->declared_access);
if (!expr) {
return nullptr;
}
- access = expr->Value();
+ sem->SetAccess(expr->Value());
} else {
- access = DefaultAccessForAddressSpace(address_space);
+ sem->SetAccess(DefaultAccessForAddressSpace(sem->AddressSpace()));
}
- if (rhs && !validator_.VariableInitializer(var, storage_ty, rhs)) {
+ sem->SetType(b.create<core::type::Reference>(sem->AddressSpace(), storage_ty, sem->Access()));
+
+ if (sem->Initializer() &&
+ !validator_.VariableInitializer(var, storage_ty, sem->Initializer())) {
return nullptr;
}
- auto* var_ty = b.create<core::type::Reference>(address_space, storage_ty, access);
-
- if (!ApplyAddressSpaceUsageToType(address_space, var_ty,
+ if (!ApplyAddressSpaceUsageToType(sem->AddressSpace(),
+ const_cast<core::type::Type*>(sem->Type()),
var->type ? var->type->source : var->source)) {
AddNote("while instantiating 'var' " + var->name->symbol.Name(), var->source);
return nullptr;
}
- sem::Variable* sem = nullptr;
if (is_global) {
- bool has_io_address_space =
- address_space == core::AddressSpace::kIn || address_space == core::AddressSpace::kOut;
+ bool has_io_address_space = sem->AddressSpace() == core::AddressSpace::kIn ||
+ sem->AddressSpace() == core::AddressSpace::kOut;
- std::optional<uint32_t> group, binding, location, index;
+ std::optional<uint32_t> group, binding;
for (auto* attribute : var->attributes) {
Mark(attribute);
enum Status { kSuccess, kErrored, kInvalid };
@@ -609,7 +601,7 @@
if (!value) {
return kErrored;
}
- location = value.Get();
+ global->SetLocation(value.Get());
return kSuccess;
},
[&](const ast::IndexAttribute* attr) {
@@ -620,7 +612,7 @@
if (!value) {
return kErrored;
}
- index = value.Get();
+ global->SetIndex(value.Get());
return kSuccess;
},
[&](const ast::BuiltinAttribute* attr) {
@@ -657,13 +649,9 @@
}
}
- std::optional<BindingPoint> binding_point;
if (group && binding) {
- binding_point = BindingPoint{group.value(), binding.value()};
+ global->SetBindingPoint(BindingPoint{group.value(), binding.value()});
}
- sem = b.create<sem::GlobalVariable>(
- var, var_ty, core::EvaluationStage::kRuntime, address_space, access,
- /* constant_value */ nullptr, binding_point, location, index);
} else {
for (auto* attribute : var->attributes) {
@@ -679,13 +667,8 @@
return nullptr;
}
}
- sem = b.create<sem::LocalVariable>(var, var_ty, core::EvaluationStage::kRuntime,
- address_space, access, current_statement_,
- /* constant_value */ nullptr);
}
- sem->SetInitializer(rhs);
- b.Sem().Add(var, sem);
return sem;
}
@@ -694,23 +677,25 @@
uint32_t index) {
Mark(param->name);
+ auto* sem = b.create<sem::Parameter>(param, index);
+ b.Sem().Add(param, sem);
+
auto add_note = [&] {
AddNote("while instantiating parameter " + param->name->symbol.Name(), param->source);
};
- std::optional<uint32_t> location, group, binding;
-
if (func->IsEntryPoint()) {
+ std::optional<uint32_t> group, binding;
for (auto* attribute : param->attributes) {
Mark(attribute);
bool ok = Switch(
attribute, //
[&](const ast::LocationAttribute* attr) {
auto value = LocationAttribute(attr);
- if (!value) {
+ if (TINT_UNLIKELY(!value)) {
return false;
}
- location = value.Get();
+ sem->SetLocation(value.Get());
return true;
},
[&](const ast::BuiltinAttribute* attr) -> bool { return BuiltinAttribute(attr); },
@@ -728,7 +713,7 @@
return false;
}
auto value = GroupAttribute(attr);
- if (!value) {
+ if (TINT_UNLIKELY(!value)) {
return false;
}
group = value.Get();
@@ -741,7 +726,7 @@
return false;
}
auto value = BindingAttribute(attr);
- if (!value) {
+ if (TINT_UNLIKELY(!value)) {
return false;
}
binding = value.Get();
@@ -755,6 +740,9 @@
return nullptr;
}
}
+ if (group && binding) {
+ sem->SetBindingPoint(BindingPoint{group.value(), binding.value()});
+ }
} else {
for (auto* attribute : param->attributes) {
Mark(attribute);
@@ -781,9 +769,10 @@
}
core::type::Type* ty = Type(param->type);
- if (!ty) {
+ if (TINT_UNLIKELY(!ty)) {
return nullptr;
}
+ sem->SetType(ty);
if (!ApplyAddressSpaceUsageToType(core::AddressSpace::kUndefined, ty, param->type->source)) {
add_note();
@@ -801,16 +790,6 @@
}
}
- std::optional<BindingPoint> binding_point;
- if (group && binding) {
- binding_point = BindingPoint{group.value(), binding.value()};
- }
-
- auto* sem = b.create<sem::Parameter>(param, index, ty, core::AddressSpace::kUndefined,
- core::Access::kUndefined, core::ParameterUsage::kNone,
- binding_point, location);
- b.Sem().Add(param, sem);
-
if (!validator_.Parameter(sem)) {
return nullptr;
}
@@ -2091,9 +2070,7 @@
if (match->info->flags.Contains(OverloadFlag::kIsConstructor)) {
// Type constructor
auto params = Transform(match->parameters, [&](auto& p, size_t i) {
- return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type,
- core::AddressSpace::kUndefined,
- core::Access::kUndefined, p.usage);
+ return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type, p.usage);
});
target_sem = constructors_.GetOrCreate(match.Get(), [&] {
return b.create<sem::ValueConstructor>(match->return_type, std::move(params),
@@ -2102,9 +2079,8 @@
} else {
// Type conversion
target_sem = converters_.GetOrCreate(match.Get(), [&] {
- auto param = b.create<sem::Parameter>(
- nullptr, 0u, match->parameters[0].type, core::AddressSpace::kUndefined,
- core::Access::kUndefined, match->parameters[0].usage);
+ auto* param = b.create<sem::Parameter>(nullptr, 0u, match->parameters[0].type,
+ match->parameters[0].usage);
return b.create<sem::ValueConversion>(match->return_type, param, overload_stage);
});
}
@@ -2199,12 +2175,9 @@
ArrayConstructorSig{{arr, args.Length(), args_stage}},
[&]() -> sem::ValueConstructor* {
auto params = tint::Transform(args, [&](auto, size_t i) {
- return b.create<sem::Parameter>(
- nullptr, // declaration
- static_cast<uint32_t>(i), // index
- arr->ElemType(), // type
- core::AddressSpace::kUndefined, // address_space
- core::Access::kUndefined);
+ return b.create<sem::Parameter>(nullptr, // declaration
+ static_cast<uint32_t>(i), // index
+ arr->ElemType());
});
return b.create<sem::ValueConstructor>(arr, std::move(params), args_stage);
});
@@ -2227,12 +2200,10 @@
Vector<sem::Parameter*, 8> params;
params.Resize(std::min(args.Length(), str->Members().Length()));
for (size_t i = 0, n = params.Length(); i < n; i++) {
- params[i] = b.create<sem::Parameter>(
- nullptr, // declaration
- static_cast<uint32_t>(i), // index
- str->Members()[i]->Type(), // type
- core::AddressSpace::kUndefined, // address_space
- core::Access::kUndefined); // access
+ params[i] =
+ b.create<sem::Parameter>(nullptr, // declaration
+ static_cast<uint32_t>(i), // index
+ str->Members()[i]->Type()); // type
}
return b.create<sem::ValueConstructor>(str, std::move(params), args_stage);
});
@@ -2405,9 +2376,7 @@
// De-duplicate builtins that are identical.
auto* target = builtins_.GetOrCreate(std::make_pair(overload.Get(), fn), [&] {
auto params = Transform(overload->parameters, [&](auto& p, size_t i) {
- return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type,
- core::AddressSpace::kUndefined,
- core::Access::kUndefined, p.usage);
+ return b.create<sem::Parameter>(nullptr, static_cast<uint32_t>(i), p.type, p.usage);
});
sem::PipelineStageSet supported_stages;
auto flags = overload->info->flags;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index 2ba9a26..3dae493 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -507,8 +507,7 @@
/// @note this method does not resolve the attributes as these are context-dependent (global,
/// local)
/// @param var the variable
- /// @param is_global true if this is module scope, otherwise function scope
- sem::Variable* Let(const ast::Let* var, bool is_global);
+ sem::Variable* Let(const ast::Let* var);
/// @returns the semantic info for the module-scope `ast::Override` `v`. If an error is raised,
/// nullptr is returned.
diff --git a/src/tint/lang/wgsl/sem/variable.cc b/src/tint/lang/wgsl/sem/variable.cc
index bfff38f..d86f4ec 100644
--- a/src/tint/lang/wgsl/sem/variable.cc
+++ b/src/tint/lang/wgsl/sem/variable.cc
@@ -28,62 +28,26 @@
TINT_INSTANTIATE_TYPEINFO(tint::sem::VariableUser);
namespace tint::sem {
-Variable::Variable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const core::constant::Value* constant_value)
- : declaration_(declaration),
- type_(type),
- stage_(stage),
- address_space_(address_space),
- access_(access),
- constant_value_(constant_value) {}
+Variable::Variable(const ast::Variable* declaration) : declaration_(declaration) {}
Variable::~Variable() = default;
-LocalVariable::LocalVariable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const sem::Statement* statement,
- const core::constant::Value* constant_value)
- : Base(declaration, type, stage, address_space, access, constant_value),
- statement_(statement) {}
+LocalVariable::LocalVariable(const ast::Variable* declaration, const sem::Statement* statement)
+ : Base(declaration), statement_(statement) {}
LocalVariable::~LocalVariable() = default;
-GlobalVariable::GlobalVariable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const core::constant::Value* constant_value,
- std::optional<tint::BindingPoint> binding_point,
- std::optional<uint32_t> location,
- std::optional<uint32_t> index)
- : Base(declaration, type, stage, address_space, access, constant_value),
- binding_point_(binding_point),
- location_(location),
- index_(index) {}
+GlobalVariable::GlobalVariable(const ast::Variable* declaration) : Base(declaration) {}
GlobalVariable::~GlobalVariable() = default;
Parameter::Parameter(const ast::Parameter* declaration,
- uint32_t index,
- const core::type::Type* type,
- core::AddressSpace address_space,
- core::Access access,
- const core::ParameterUsage usage /* = ParameterUsage::kNone */,
- std::optional<tint::BindingPoint> binding_point /* = {} */,
- std::optional<uint32_t> location /* = std::nullopt */)
- : Base(declaration, type, core::EvaluationStage::kRuntime, address_space, access, nullptr),
- index_(index),
- usage_(usage),
- binding_point_(binding_point),
- location_(location) {}
+ uint32_t index /* = 0 */,
+ const core::type::Type* type /* = nullptr */,
+ core::ParameterUsage usage /* = core::ParameterUsage::kNone */)
+ : Base(declaration), index_(index), usage_(usage) {
+ SetType(type);
+}
Parameter::~Parameter() = default;
diff --git a/src/tint/lang/wgsl/sem/variable.h b/src/tint/lang/wgsl/sem/variable.h
index a894a83..82d9603 100644
--- a/src/tint/lang/wgsl/sem/variable.h
+++ b/src/tint/lang/wgsl/sem/variable.h
@@ -49,17 +49,7 @@
public:
/// Constructor
/// @param declaration the AST declaration node
- /// @param type the variable type
- /// @param stage the evaluation stage for an expression of this variable type
- /// @param address_space the variable address space
- /// @param access the variable access control type
- /// @param constant_value the constant value for the variable. May be null
- Variable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const core::constant::Value* constant_value);
+ explicit Variable(const ast::Variable* declaration);
/// Destructor
~Variable() override;
@@ -67,29 +57,44 @@
/// @returns the AST declaration node
const ast::Variable* Declaration() const { return declaration_; }
+ /// @param type the variable type
+ void SetType(const core::type::Type* type) { type_ = type; }
+
/// @returns the canonical type for the variable
const core::type::Type* Type() const { return type_; }
+ /// @param stage the evaluation stage for an expression of this variable type
+ void SetStage(core::EvaluationStage stage) { stage_ = stage; }
+
/// @returns the evaluation stage for an expression of this variable type
core::EvaluationStage Stage() const { return stage_; }
+ /// @param space the variable address space
+ void SetAddressSpace(core::AddressSpace space) { address_space_ = space; }
+
/// @returns the address space for the variable
core::AddressSpace AddressSpace() const { return address_space_; }
+ /// @param access the variable access control type
+ void SetAccess(core::Access access) { access_ = access; }
+
/// @returns the access control for the variable
core::Access Access() const { return access_; }
+ /// @param value the constant value for the variable. May be null
+ void SetConstantValue(const core::constant::Value* value) { constant_value_ = value; }
+
/// @return the constant value of this expression
const core::constant::Value* ConstantValue() const { return constant_value_; }
- /// @returns the variable initializer expression, or nullptr if the variable
- /// does not have one.
- const ValueExpression* Initializer() const { return initializer_; }
-
/// Sets the variable initializer expression.
/// @param initializer the initializer expression to assign to this variable.
void SetInitializer(const ValueExpression* initializer) { initializer_ = initializer; }
+ /// @returns the variable initializer expression, or nullptr if the variable
+ /// does not have one.
+ const ValueExpression* Initializer() const { return initializer_; }
+
/// @returns the expressions that use the variable
VectorRef<const VariableUser*> Users() const { return users_; }
@@ -97,12 +102,12 @@
void AddUser(const VariableUser* user) { users_.Push(user); }
private:
- const ast::Variable* const declaration_;
- const core::type::Type* const type_;
- const core::EvaluationStage stage_;
- const core::AddressSpace address_space_;
- const core::Access access_;
- const core::constant::Value* constant_value_;
+ const ast::Variable* const declaration_ = nullptr;
+ const core::type::Type* type_ = nullptr;
+ core::EvaluationStage stage_ = core::EvaluationStage::kRuntime;
+ core::AddressSpace address_space_ = core::AddressSpace::kUndefined;
+ core::Access access_ = core::Access::kUndefined;
+ const core::constant::Value* constant_value_ = nullptr;
const ValueExpression* initializer_ = nullptr;
tint::Vector<const VariableUser*, 8> users_;
};
@@ -112,19 +117,8 @@
public:
/// Constructor
/// @param declaration the AST declaration node
- /// @param type the variable type
- /// @param stage the evaluation stage for an expression of this variable type
- /// @param address_space the variable address space
- /// @param access the variable access control type
/// @param statement the statement that declared this local variable
- /// @param constant_value the constant value for the variable. May be null
- LocalVariable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const sem::Statement* statement,
- const core::constant::Value* constant_value);
+ LocalVariable(const ast::Variable* declaration, const sem::Statement* statement);
/// Destructor
~LocalVariable() override;
@@ -132,13 +126,13 @@
/// @returns the statement that declares this local variable
const sem::Statement* Statement() const { return statement_; }
- /// @returns the Type, Function or Variable that this local variable shadows
- const CastableBase* Shadows() const { return shadows_; }
-
/// Sets the Type, Function or Variable that this local variable shadows
/// @param shadows the Type, Function or Variable that this variable shadows
void SetShadows(const CastableBase* shadows) { shadows_ = shadows; }
+ /// @returns the Type, Function or Variable that this local variable shadows
+ const CastableBase* Shadows() const { return shadows_; }
+
private:
const sem::Statement* const statement_;
const CastableBase* shadows_ = nullptr;
@@ -149,30 +143,16 @@
public:
/// Constructor
/// @param declaration the AST declaration node
- /// @param type the variable type
- /// @param stage the evaluation stage for an expression of this variable type
- /// @param address_space the variable address space
- /// @param access the variable access control type
- /// @param constant_value the constant value for the variable. May be null
- /// @param binding_point the optional resource binding point of the variable
- /// @param location the location value if provided
- /// @param index the index value if provided
- ///
- /// Note, a GlobalVariable generally doesn't have a `location` in WGSL, as it isn't allowed by
- /// the spec. The location maybe attached by transforms such as CanonicalizeEntryPointIO.
- GlobalVariable(const ast::Variable* declaration,
- const core::type::Type* type,
- core::EvaluationStage stage,
- core::AddressSpace address_space,
- core::Access access,
- const core::constant::Value* constant_value,
- std::optional<tint::BindingPoint> binding_point = std::nullopt,
- std::optional<uint32_t> location = std::nullopt,
- std::optional<uint32_t> index = std::nullopt);
+ explicit GlobalVariable(const ast::Variable* declaration);
/// Destructor
~GlobalVariable() override;
+ /// @param binding_point the resource binding point for the parameter
+ void SetBindingPoint(std::optional<tint::BindingPoint> binding_point) {
+ binding_point_ = binding_point;
+ }
+
/// @returns the resource binding point for the variable
std::optional<tint::BindingPoint> BindingPoint() const { return binding_point_; }
@@ -182,15 +162,22 @@
/// @returns the pipeline constant ID associated with the variable
tint::OverrideId OverrideId() const { return override_id_; }
+ /// @param location the location value for the parameter, if set
+ /// @note a GlobalVariable generally doesn't have a `location` in WGSL, as it isn't allowed by
+ /// the spec. The location maybe attached by transforms such as CanonicalizeEntryPointIO.
+ void SetLocation(std::optional<uint32_t> location) { location_ = location; }
+
/// @returns the location value for the parameter, if set
std::optional<uint32_t> Location() const { return location_; }
+ /// @param index the index value for the parameter, if set
+ void SetIndex(std::optional<uint32_t> index) { index_ = index; }
+
/// @returns the index value for the parameter, if set
std::optional<uint32_t> Index() const { return index_; }
private:
- const std::optional<tint::BindingPoint> binding_point_;
-
+ std::optional<tint::BindingPoint> binding_point_;
tint::OverrideId override_id_;
std::optional<uint32_t> location_;
std::optional<uint32_t> index_;
@@ -199,23 +186,15 @@
/// Parameter is a function parameter
class Parameter final : public Castable<Parameter, Variable> {
public:
- /// Constructor for function parameters
+ /// Constructor
/// @param declaration the AST declaration node
- /// @param index the index of the parmeter in the function
+ /// @param index the index of the parameter in the function
/// @param type the variable type
- /// @param address_space the variable address space
- /// @param access the variable access control type
- /// @param usage the semantic usage for the parameter
- /// @param binding_point the optional resource binding point of the parameter
- /// @param location the location value, if set
+ /// @param usage the parameter usage
Parameter(const ast::Parameter* declaration,
- uint32_t index,
- const core::type::Type* type,
- core::AddressSpace address_space,
- core::Access access,
- const core::ParameterUsage usage = core::ParameterUsage::kNone,
- std::optional<tint::BindingPoint> binding_point = {},
- std::optional<uint32_t> location = std::nullopt);
+ uint32_t index = 0,
+ const core::type::Type* type = nullptr,
+ core::ParameterUsage usage = core::ParameterUsage::kNone);
/// Destructor
~Parameter() override;
@@ -225,38 +204,52 @@
return static_cast<const ast::Parameter*>(Variable::Declaration());
}
+ /// @param index the index value for the parameter, if set
+ void SetIndex(uint32_t index) { index_ = index; }
+
/// @return the index of the parameter in the function
uint32_t Index() const { return index_; }
+ /// @param usage the semantic usage for the parameter
+ void SetUsage(core::ParameterUsage usage) { usage_ = usage; }
+
/// @returns the semantic usage for the parameter
core::ParameterUsage Usage() const { return usage_; }
- /// @returns the CallTarget owner of this parameter
- CallTarget const* Owner() const { return owner_; }
-
/// @param owner the CallTarget owner of this parameter
- void SetOwner(CallTarget const* owner) { owner_ = owner; }
+ void SetOwner(const CallTarget* owner) { owner_ = owner; }
- /// @returns the Type, Function or Variable that this local variable shadows
- const CastableBase* Shadows() const { return shadows_; }
+ /// @returns the CallTarget owner of this parameter
+ const CallTarget* Owner() const { return owner_; }
/// Sets the Type, Function or Variable that this local variable shadows
/// @param shadows the Type, Function or Variable that this variable shadows
void SetShadows(const CastableBase* shadows) { shadows_ = shadows; }
+ /// @returns the Type, Function or Variable that this local variable shadows
+ const CastableBase* Shadows() const { return shadows_; }
+
+ /// @param binding_point the resource binding point for the parameter
+ void SetBindingPoint(std::optional<tint::BindingPoint> binding_point) {
+ binding_point_ = binding_point;
+ }
+
/// @returns the resource binding point for the parameter
std::optional<tint::BindingPoint> BindingPoint() const { return binding_point_; }
+ /// @param location the location value for the parameter, if set
+ void SetLocation(std::optional<uint32_t> location) { location_ = location; }
+
/// @returns the location value for the parameter, if set
std::optional<uint32_t> Location() const { return location_; }
private:
- const uint32_t index_;
- const core::ParameterUsage usage_;
+ uint32_t index_ = 0;
+ core::ParameterUsage usage_ = core::ParameterUsage::kNone;
CallTarget const* owner_ = nullptr;
const CastableBase* shadows_ = nullptr;
- const std::optional<tint::BindingPoint> binding_point_;
- const std::optional<uint32_t> location_;
+ std::optional<tint::BindingPoint> binding_point_;
+ std::optional<uint32_t> location_;
};
/// VariableUser holds the semantic information for an identifier expression