Add semantic::Variable::Type() and use it instead of ast::Variable::type()
In anticipation of adding support for type inference, no longer use
ast::Variable::type() everywhere, as it will eventually return nullptr
for type-inferred variables. Instead, the Resolver now stores the final
resolved type into the semantic::Variable, and nearly all code now makes
use of that.
ast::Variable::type() has been renamed to ast::Variable::declared_type()
to help make its usage clear, and to distinguish it from
semantic::Variable::Type().
Fixed tests that failed after this change because variables were missing
VariableDeclStatements, so there was no path to the variables during
resolving, and thus no semantic info generated for them.
Bug: tint:672
Change-Id: I0125e2f555839a4892248dc6739a72e9c7f51b1e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46100
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index b2d3174..6e438d8 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -125,7 +125,9 @@
out << "__func" + return_type_->type_name();
for (auto* param : params_) {
- out << param->type()->type_name();
+ // No need for the semantic::Variable here, functions params must have a
+ // type
+ out << param->declared_type()->type_name();
}
return out.str();
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index a7a5f1b..cab6350 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -25,20 +25,20 @@
Variable::Variable(const Source& source,
const Symbol& sym,
- StorageClass sc,
- type::Type* type,
+ StorageClass declared_storage_class,
+ type::Type* declared_type,
bool is_const,
Expression* constructor,
DecorationList decorations)
: Base(source),
symbol_(sym),
- type_(type),
+ declared_type_(declared_type),
is_const_(is_const),
constructor_(constructor),
decorations_(std::move(decorations)),
- declared_storage_class_(sc) {
+ declared_storage_class_(declared_storage_class) {
TINT_ASSERT(symbol_.IsValid());
- TINT_ASSERT(type_);
+ TINT_ASSERT(declared_type_);
}
Variable::Variable(Variable&&) = default;
@@ -94,7 +94,7 @@
Variable* Variable::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source());
auto sym = ctx->Clone(symbol());
- auto* ty = ctx->Clone(type());
+ auto* ty = ctx->Clone(declared_type());
auto* ctor = ctx->Clone(constructor());
auto decos = ctx->Clone(decorations());
return ctx->dst->create<Variable>(src, sym, declared_storage_class(), ty,
@@ -111,7 +111,7 @@
out << (var_sem ? var_sem->StorageClass() : declared_storage_class())
<< std::endl;
make_indent(out, indent);
- out << type_->type_name() << std::endl;
+ out << declared_type_->type_name() << std::endl;
}
void Variable::constructor_to_str(const semantic::Info& sem,
diff --git a/src/ast/variable.h b/src/ast/variable.h
index a7c7b65..f5f37cb 100644
--- a/src/ast/variable.h
+++ b/src/ast/variable.h
@@ -79,15 +79,15 @@
/// Create a variable
/// @param source the variable source
/// @param sym the variable symbol
- /// @param sc the declared storage class
- /// @param type the value type
+ /// @param declared_storage_class the declared storage class
+ /// @param declared_type the declared variable type
/// @param is_const true if the variable is const
/// @param constructor the constructor expression
/// @param decorations the variable decorations
Variable(const Source& source,
const Symbol& sym,
- StorageClass sc,
- type::Type* type,
+ StorageClass declared_storage_class,
+ type::Type* declared_type,
bool is_const,
Expression* constructor,
DecorationList decorations);
@@ -99,8 +99,8 @@
/// @returns the variable symbol
const Symbol& symbol() const { return symbol_; }
- /// @returns the variable's type.
- type::Type* type() const { return type_; }
+ /// @returns the declared type
+ type::Type* declared_type() const { return declared_type_; }
/// @returns the declared storage class
StorageClass declared_storage_class() const {
@@ -166,7 +166,7 @@
Symbol const symbol_;
// The value type if a const or formal paramter, and the store type if a var
- type::Type* const type_;
+ type::Type* const declared_type_;
bool const is_const_;
Expression* const constructor_;
DecorationList const decorations_;
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 4d5b967..2651f1a 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -27,7 +27,7 @@
EXPECT_EQ(v->symbol(), Symbol(1));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction);
- EXPECT_EQ(v->type(), ty.i32());
+ EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_EQ(v->source().range.begin.line, 0u);
EXPECT_EQ(v->source().range.begin.column, 0u);
EXPECT_EQ(v->source().range.end.line, 0u);
@@ -41,7 +41,7 @@
EXPECT_EQ(v->symbol(), Symbol(1));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate);
- EXPECT_EQ(v->type(), ty.f32());
+ EXPECT_EQ(v->declared_type(), ty.f32());
EXPECT_EQ(v->source().range.begin.line, 27u);
EXPECT_EQ(v->source().range.begin.column, 4u);
EXPECT_EQ(v->source().range.end.line, 27u);
@@ -55,7 +55,7 @@
EXPECT_EQ(v->symbol(), Symbol(1));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup);
- EXPECT_EQ(v->type(), ty.i32());
+ EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_EQ(v->source().range.begin.line, 27u);
EXPECT_EQ(v->source().range.begin.column, 4u);
EXPECT_EQ(v->source().range.end.line, 27u);
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index a45c811..5df2847 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -211,7 +211,7 @@
stage_variable.name = name;
stage_variable.component_type = ComponentType::kUnknown;
- auto* type = var->Declaration()->type()->UnwrapAll();
+ auto* type = var->Type()->UnwrapAll();
if (type->is_float_scalar_or_vector() || type->is_float_matrix()) {
stage_variable.component_type = ComponentType::kFloat;
} else if (type->is_unsigned_scalar_or_vector()) {
@@ -367,10 +367,9 @@
auto* func_sem = program_->Sem().Get(func);
for (auto& ruv : func_sem->ReferencedUniformVariables()) {
auto* var = ruv.first;
- auto* decl = var->Declaration();
auto binding_info = ruv.second;
- auto* unwrapped_type = decl->type()->UnwrapIfNeeded();
+ auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
auto* str = unwrapped_type->As<type::Struct>();
if (str == nullptr) {
continue;
@@ -492,7 +491,6 @@
auto* func_sem = program_->Sem().Get(func);
for (auto& ref : func_sem->ReferencedDepthTextureVariables()) {
auto* var = ref.first;
- auto* decl = var->Declaration();
auto binding_info = ref.second;
ResourceBinding entry;
@@ -500,7 +498,7 @@
entry.bind_group = binding_info.group->value();
entry.binding = binding_info.binding->value();
- auto* texture_type = decl->type()->UnwrapIfNeeded()->As<type::Texture>();
+ auto* texture_type = var->Type()->UnwrapIfNeeded()->As<type::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim());
@@ -537,10 +535,9 @@
std::vector<ResourceBinding> result;
for (auto& rsv : func_sem->ReferencedStorageBufferVariables()) {
auto* var = rsv.first;
- auto* decl = var->Declaration();
auto binding_info = rsv.second;
- auto* ac_type = decl->type()->As<type::AccessControl>();
+ auto* ac_type = var->Type()->As<type::AccessControl>();
if (ac_type == nullptr) {
continue;
}
@@ -549,7 +546,7 @@
continue;
}
- auto* str = decl->type()->UnwrapIfNeeded()->As<type::Struct>();
+ auto* str = var->Type()->UnwrapIfNeeded()->As<type::Struct>();
if (!str) {
continue;
}
@@ -591,7 +588,6 @@
: func_sem->ReferencedSampledTextureVariables();
for (auto& ref : referenced_variables) {
auto* var = ref.first;
- auto* decl = var->Declaration();
auto binding_info = ref.second;
ResourceBinding entry;
@@ -601,7 +597,7 @@
entry.bind_group = binding_info.group->value();
entry.binding = binding_info.binding->value();
- auto* texture_type = decl->type()->UnwrapIfNeeded()->As<type::Texture>();
+ auto* texture_type = var->Type()->UnwrapIfNeeded()->As<type::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim());
@@ -634,10 +630,9 @@
std::vector<ResourceBinding> result;
for (auto& ref : func_sem->ReferencedStorageTextureVariables()) {
auto* var = ref.first;
- auto* decl = var->Declaration();
auto binding_info = ref.second;
- auto* ac_type = decl->type()->As<type::AccessControl>();
+ auto* ac_type = var->Type()->As<type::AccessControl>();
if (ac_type == nullptr) {
continue;
}
@@ -654,7 +649,7 @@
entry.binding = binding_info.binding->value();
auto* texture_type =
- decl->type()->UnwrapIfNeeded()->As<type::StorageTexture>();
+ var->Type()->UnwrapIfNeeded()->As<type::StorageTexture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim());
diff --git a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
index c6f93d5..11d5c93 100644
--- a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
@@ -32,8 +32,8 @@
EXPECT_TRUE(e->is_const());
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- ASSERT_NE(e->type(), nullptr);
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ ASSERT_NE(e->declared_type(), nullptr);
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->source().range.begin.line, 1u);
EXPECT_EQ(e->source().range.begin.column, 7u);
@@ -112,8 +112,8 @@
EXPECT_TRUE(e->is_const());
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- ASSERT_NE(e->type(), nullptr);
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ ASSERT_NE(e->declared_type(), nullptr);
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->source().range.begin.line, 1u);
EXPECT_EQ(e->source().range.begin.column, 26u);
diff --git a/src/reader/wgsl/parser_impl_global_variable_decl_test.cc b/src/reader/wgsl/parser_impl_global_variable_decl_test.cc
index 89d6909..831d115 100644
--- a/src/reader/wgsl/parser_impl_global_variable_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_variable_decl_test.cc
@@ -31,7 +31,7 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput);
EXPECT_EQ(e->source().range.begin.line, 1u);
@@ -54,7 +54,7 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput);
EXPECT_EQ(e->source().range.begin.line, 1u);
@@ -79,8 +79,8 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- ASSERT_NE(e->type(), nullptr);
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ ASSERT_NE(e->declared_type(), nullptr);
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput);
EXPECT_EQ(e->source().range.begin.line, 1u);
@@ -109,8 +109,8 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
- ASSERT_NE(e->type(), nullptr);
- EXPECT_TRUE(e->type()->Is<type::F32>());
+ ASSERT_NE(e->declared_type(), nullptr);
+ EXPECT_TRUE(e->declared_type()->Is<type::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput);
EXPECT_EQ(e->source().range.begin.line, 1u);
@@ -180,7 +180,7 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s"));
- EXPECT_TRUE(e->type()->Is<type::Sampler>());
+ EXPECT_TRUE(e->declared_type()->Is<type::Sampler>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant);
}
@@ -196,7 +196,7 @@
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s"));
- EXPECT_TRUE(e->type()->UnwrapAll()->Is<type::Texture>());
+ EXPECT_TRUE(e->declared_type()->UnwrapAll()->Is<type::Texture>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant);
}
diff --git a/src/reader/wgsl/parser_impl_param_list_test.cc b/src/reader/wgsl/parser_impl_param_list_test.cc
index 0bf7686..c581145 100644
--- a/src/reader/wgsl/parser_impl_param_list_test.cc
+++ b/src/reader/wgsl/parser_impl_param_list_test.cc
@@ -30,7 +30,7 @@
EXPECT_EQ(e.value.size(), 1u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a"));
- EXPECT_EQ(e.value[0]->type(), i32);
+ EXPECT_EQ(e.value[0]->declared_type(), i32);
EXPECT_TRUE(e.value[0]->is_const());
ASSERT_EQ(e.value[0]->source().range.begin.line, 1u);
@@ -52,7 +52,7 @@
EXPECT_EQ(e.value.size(), 3u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a"));
- EXPECT_EQ(e.value[0]->type(), i32);
+ EXPECT_EQ(e.value[0]->declared_type(), i32);
EXPECT_TRUE(e.value[0]->is_const());
ASSERT_EQ(e.value[0]->source().range.begin.line, 1u);
@@ -61,7 +61,7 @@
ASSERT_EQ(e.value[0]->source().range.end.column, 2u);
EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("b"));
- EXPECT_EQ(e.value[1]->type(), f32);
+ EXPECT_EQ(e.value[1]->declared_type(), f32);
EXPECT_TRUE(e.value[1]->is_const());
ASSERT_EQ(e.value[1]->source().range.begin.line, 1u);
@@ -70,7 +70,7 @@
ASSERT_EQ(e.value[1]->source().range.end.column, 11u);
EXPECT_EQ(e.value[2]->symbol(), p->builder().Symbols().Get("c"));
- EXPECT_EQ(e.value[2]->type(), vec2);
+ EXPECT_EQ(e.value[2]->declared_type(), vec2);
EXPECT_TRUE(e.value[2]->is_const());
ASSERT_EQ(e.value[2]->source().range.begin.line, 1u);
@@ -109,7 +109,7 @@
ASSERT_EQ(e.value.size(), 2u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("coord"));
- EXPECT_EQ(e.value[0]->type(), vec4);
+ EXPECT_EQ(e.value[0]->declared_type(), vec4);
EXPECT_TRUE(e.value[0]->is_const());
auto decos0 = e.value[0]->decorations();
ASSERT_EQ(decos0.size(), 1u);
@@ -123,7 +123,7 @@
ASSERT_EQ(e.value[0]->source().range.end.column, 30u);
EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("loc1"));
- EXPECT_EQ(e.value[1]->type(), f32);
+ EXPECT_EQ(e.value[1]->declared_type(), f32);
EXPECT_TRUE(e.value[1]->is_const());
auto decos1 = e.value[1]->decorations();
ASSERT_EQ(decos1.size(), 1u);
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 189202a..bdf95be 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -201,7 +201,8 @@
return false;
}
} else if (auto* var = decl->As<ast::Variable>()) {
- variable_stack_.set_global(var->symbol(), CreateVariableInfo(var));
+ auto* info = CreateVariableInfo(var);
+ variable_stack_.set_global(var->symbol(), info);
if (var->has_constructor()) {
if (!Expression(var->constructor())) {
@@ -210,7 +211,7 @@
}
if (!ApplyStorageClassUsageToType(var->declared_storage_class(),
- var->type(), var->source())) {
+ info->type, var->source())) {
diagnostics_.add_note("while instantiating variable " +
builder_->Symbols().NameFor(var->symbol()),
var->source());
@@ -223,7 +224,8 @@
}
bool Resolver::ValidateParameter(const ast::Variable* param) {
- if (auto* r = param->type()->UnwrapAll()->As<type::Array>()) {
+ auto* type = variable_to_info_[param]->type;
+ if (auto* r = type->UnwrapAll()->As<type::Array>()) {
if (r->IsRuntimeArray()) {
diagnostics_.add_error(
"v-0015",
@@ -277,10 +279,6 @@
bool Resolver::Function(ast::Function* func) {
auto* func_info = function_infos_.Create<FunctionInfo>(func);
- if (!ValidateFunction(func)) {
- return false;
- }
-
ScopedAssignment<FunctionInfo*> sa(current_function_, func_info);
variable_stack_.push_scope();
@@ -293,6 +291,10 @@
}
variable_stack_.pop_scope();
+ if (!ValidateFunction(func)) {
+ return false;
+ }
+
// Register the function information _after_ processing the statements. This
// allows us to catch a function calling itself when determining the call
// information as this function doesn't exist until it's finished.
@@ -780,12 +782,12 @@
// A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type.
if (var->declaration->is_const()) {
- SetType(expr, var->declaration->type());
- } else if (var->declaration->type()->Is<type::Pointer>()) {
- SetType(expr, var->declaration->type());
+ SetType(expr, var->type);
+ } else if (var->type->Is<type::Pointer>()) {
+ SetType(expr, var->type);
} else {
- SetType(expr, builder_->create<type::Pointer>(var->declaration->type(),
- var->storage_class));
+ SetType(expr,
+ builder_->create<type::Pointer>(var->type, var->storage_class));
}
var->users.push_back(expr);
@@ -1200,15 +1202,18 @@
}
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
+ ast::Variable* var = stmt->variable();
+ type::Type* type = var->declared_type();
+
if (auto* ctor = stmt->variable()->constructor()) {
if (!Expression(ctor)) {
return false;
}
- auto* lhs_type = stmt->variable()->type();
auto* rhs_type = TypeOf(ctor);
- if (!IsValidAssignment(lhs_type, rhs_type)) {
+
+ if (!IsValidAssignment(type, rhs_type)) {
diagnostics_.add_error(
- "variable of type '" + lhs_type->FriendlyName(builder_->Symbols()) +
+ "variable of type '" + type->FriendlyName(builder_->Symbols()) +
"' cannot be initialized with a value of type '" +
rhs_type->FriendlyName(builder_->Symbols()) + "'",
stmt->source());
@@ -1216,10 +1221,8 @@
}
}
- auto* var = stmt->variable();
-
auto* info = CreateVariableInfo(var);
- variable_to_info_.emplace(var, info);
+ info->type = type;
variable_stack_.set(var->symbol(), info);
current_block_->decls.push_back(var);
@@ -1235,7 +1238,7 @@
}
}
- if (!ApplyStorageClassUsageToType(info->storage_class, var->type(),
+ if (!ApplyStorageClassUsageToType(info->storage_class, info->type,
var->source())) {
diagnostics_.add_note("while instantiating variable " +
builder_->Symbols().NameFor(var->symbol()),
@@ -1303,8 +1306,8 @@
}
users.push_back(sem_expr);
}
- sem.Add(var, builder_->create<semantic::Variable>(var, info->storage_class,
- std::move(users)));
+ sem.Add(var, builder_->create<semantic::Variable>(
+ var, info->type, info->storage_class, std::move(users)));
}
auto remap_vars = [&sem](const std::vector<VariableInfo*>& in) {
@@ -1812,7 +1815,9 @@
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
- : declaration(decl), storage_class(decl->declared_storage_class()) {}
+ : declaration(decl),
+ type(decl->declared_type()),
+ storage_class(decl->declared_storage_class()) {}
Resolver::VariableInfo::~VariableInfo() = default;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 3109dc6..3f39d73 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -93,6 +93,7 @@
~VariableInfo();
ast::Variable* const declaration;
+ type::Type* type;
ast::StorageClass storage_class;
std::vector<ast::IdentifierExpression*> users;
};
@@ -290,7 +291,7 @@
ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
- std::unordered_map<ast::Variable*, VariableInfo*> variable_to_info_;
+ std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<type::Struct*, StructInfo*> struct_info_;
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc
index 15e0b65..8e4dcbd 100644
--- a/src/semantic/sem_function.cc
+++ b/src/semantic/sem_function.cc
@@ -32,7 +32,8 @@
ParameterList parameters;
parameters.reserve(ast->params().size());
for (auto* param : ast->params()) {
- parameters.emplace_back(Parameter{param->type(), Parameter::Usage::kNone});
+ parameters.emplace_back(
+ Parameter{param->declared_type(), Parameter::Usage::kNone});
}
return parameters;
}
@@ -160,7 +161,8 @@
VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) {
- auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded();
+ auto* unwrapped_type =
+ var->Declaration()->declared_type()->UnwrapIfNeeded();
auto* storage_texture = unwrapped_type->As<type::StorageTexture>();
if (storage_texture == nullptr) {
continue;
@@ -182,7 +184,8 @@
VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) {
- auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded();
+ auto* unwrapped_type =
+ var->Declaration()->declared_type()->UnwrapIfNeeded();
auto* storage_texture = unwrapped_type->As<type::DepthTexture>();
if (storage_texture == nullptr) {
continue;
@@ -229,7 +232,8 @@
VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) {
- auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded();
+ auto* unwrapped_type =
+ var->Declaration()->declared_type()->UnwrapIfNeeded();
auto* sampler = unwrapped_type->As<type::Sampler>();
if (sampler == nullptr || sampler->kind() != kind) {
continue;
@@ -252,7 +256,8 @@
VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) {
- auto* unwrapped_type = var->Declaration()->type()->UnwrapIfNeeded();
+ auto* unwrapped_type =
+ var->Declaration()->declared_type()->UnwrapIfNeeded();
auto* texture = unwrapped_type->As<type::Texture>();
if (texture == nullptr) {
continue;
diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc
index 1051a6e..03dc034 100644
--- a/src/semantic/sem_variable.cc
+++ b/src/semantic/sem_variable.cc
@@ -19,10 +19,12 @@
namespace tint {
namespace semantic {
-Variable::Variable(ast::Variable* declaration,
+Variable::Variable(const ast::Variable* declaration,
+ type::Type* type,
ast::StorageClass storage_class,
std::vector<const Expression*> users)
: declaration_(declaration),
+ type_(type),
storage_class_(storage_class),
users_(std::move(users)) {}
diff --git a/src/semantic/variable.h b/src/semantic/variable.h
index 96f023a..0e606a2 100644
--- a/src/semantic/variable.h
+++ b/src/semantic/variable.h
@@ -37,9 +37,11 @@
public:
/// Constructor
/// @param declaration the AST declaration node
+ /// @param type the variable type
/// @param storage_class the variable storage class
/// @param users the expressions that use the variable
- explicit Variable(ast::Variable* declaration,
+ explicit Variable(const ast::Variable* declaration,
+ type::Type* type,
ast::StorageClass storage_class,
std::vector<const Expression*> users);
@@ -47,7 +49,10 @@
~Variable() override;
/// @returns the AST declaration node
- ast::Variable* Declaration() const { return declaration_; }
+ const ast::Variable* Declaration() const { return declaration_; }
+
+ /// @returns the type for the variable
+ type::Type* Type() const { return type_; }
/// @returns the storage class for the variable
ast::StorageClass StorageClass() const { return storage_class_; }
@@ -56,7 +61,8 @@
const std::vector<const Expression*>& Users() const { return users_; }
private:
- ast::Variable* const declaration_;
+ const ast::Variable* const declaration_;
+ type::Type* const type_;
ast::StorageClass const storage_class_;
std::vector<const Expression*> const users_;
};
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index eccd813..cdc0288 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -39,7 +39,7 @@
// Clone arguments outside of create() call to have deterministic ordering
auto source = ctx->Clone(in->source());
auto symbol = ctx->dst->Symbols().Register(new_name);
- auto* type = ctx->Clone(in->type());
+ auto* type = ctx->Clone(in->declared_type());
auto* constructor = ctx->Clone(in->constructor());
auto decorations = ctx->Clone(in->decorations());
return ctx->dst->create<ast::Variable>(
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 494ca38..8ed06ab 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -145,7 +145,8 @@
// Build a new structure to hold the non-struct input parameters.
ast::StructMemberList struct_members;
for (auto* param : func->params()) {
- if (param->type()->Is<type::Struct>()) {
+ auto* type = ctx.src->Sem().Get(param)->Type();
+ if (type->Is<type::Struct>()) {
// Already a struct, nothing to do.
continue;
}
@@ -159,14 +160,12 @@
auto* deco = param->decorations()[0];
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
// Create a struct member with the builtin decoration.
- struct_members.push_back(
- ctx.dst->Member(name, ctx.Clone(param->type()),
- ast::DecorationList{ctx.Clone(builtin)}));
+ struct_members.push_back(ctx.dst->Member(
+ name, ctx.Clone(type), ast::DecorationList{ctx.Clone(builtin)}));
} else if (auto* loc = deco->As<ast::LocationDecoration>()) {
// Create a struct member with the location decoration.
- struct_members.push_back(
- ctx.dst->Member(name, ctx.Clone(param->type()),
- ast::DecorationList{ctx.Clone(loc)}));
+ struct_members.push_back(ctx.dst->Member(
+ name, ctx.Clone(type), ast::DecorationList{ctx.Clone(loc)}));
} else {
TINT_ICE(ctx.dst->Diagnostics())
<< "Unsupported entry point parameter decoration";
@@ -195,7 +194,8 @@
// Replace the original parameters with function-scope constants.
for (auto* param : func->params()) {
- if (param->type()->Is<type::Struct>()) {
+ auto* type = ctx.src->Sem().Get(param)->Type();
+ if (type->Is<type::Struct>()) {
// Keep struct parameters unchanged.
new_parameters.push_back(ctx.Clone(param));
continue;
@@ -207,7 +207,7 @@
// Initialize it with the value extracted from the struct parameter.
auto func_const_symbol = ctx.dst->Symbols().Register(name);
auto* func_const =
- ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()),
+ ctx.dst->Const(func_const_symbol, ctx.Clone(type),
ctx.dst->MemberAccessor(struct_param_symbol, name));
new_body.push_back(ctx.dst->WrapInStatement(func_const));
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index 3110418..6245fc7 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -14,6 +14,7 @@
#include "src/transform/msl.h"
+#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -326,9 +327,10 @@
continue;
} else if (auto* loc = deco->As<ast::LocationDecoration>()) {
// Create a struct member with the location decoration.
- struct_members.push_back(ctx.dst->Member(
- ctx.src->Symbols().NameFor(param->symbol()),
- ctx.Clone(param->type()), ast::DecorationList{ctx.Clone(loc)}));
+ std::string name = ctx.src->Symbols().NameFor(param->symbol());
+ auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type());
+ struct_members.push_back(
+ ctx.dst->Member(name, type, ast::DecorationList{ctx.Clone(loc)}));
} else {
TINT_ICE(ctx.dst->Diagnostics())
<< "Unsupported entry point parameter decoration";
@@ -368,9 +370,9 @@
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the struct parameter.
auto func_const_symbol = ctx.dst->Symbols().Register(name);
- auto* func_const =
- ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()),
- ctx.dst->MemberAccessor(struct_param_symbol, name));
+ auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type());
+ auto* constructor = ctx.dst->MemberAccessor(struct_param_symbol, name);
+ auto* func_const = ctx.dst->Const(func_const_symbol, type, constructor);
new_body.push_back(ctx.dst->WrapInStatement(func_const));
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 0f872d3..06bd5f6 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -137,8 +137,8 @@
}
for (auto* param : func->params()) {
- Symbol new_var =
- HoistToInputVariables(ctx, func, param->type(), param->decorations());
+ Symbol new_var = HoistToInputVariables(
+ ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 78e96c7..638eb53 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -207,7 +207,7 @@
Source{}, // source
ctx.dst->Symbols().Register(name), // symbol
ast::StorageClass::kPrivate, // storage_class
- ctx.Clone(v->type()), // type
+ ctx.Clone(v->declared_type()), // type
false, // is_const
nullptr, // constructor
ast::DecorationList{}); // decorations
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index f9aea11..a00c5fa 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -219,7 +219,8 @@
// storable.
// - types match or the RHS can be dereferenced to equal the LHS type.
variable_stack_.set(symbol, decl->variable());
- if (auto* arr = decl->variable()->type()->UnwrapAll()->As<type::Array>()) {
+ if (auto* arr =
+ decl->variable()->declared_type()->UnwrapAll()->As<type::Array>()) {
if (arr->IsRuntimeArray()) {
add_error(
decl->source(), "v-0015",
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 1b940e8..c5e4a4c 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1515,11 +1515,13 @@
}
first = false;
- if (!EmitType(out, v->type(), builder_.Symbols().NameFor(v->symbol()))) {
+ auto* type = builder_.Sem().Get(v)->Type();
+
+ if (!EmitType(out, type, builder_.Symbols().NameFor(v->symbol()))) {
return false;
}
// Array name is output as part of the type
- if (!v->type()->Is<type::Array>()) {
+ if (!type->Is<type::Array>()) {
out << " " << builder_.Symbols().NameFor(v->symbol());
}
}
@@ -1541,8 +1543,8 @@
std::ostream& out,
ast::Function* func,
std::unordered_set<Symbol>& emitted_globals) {
- std::vector<std::pair<ast::Variable*, ast::Decoration*>> in_variables;
- std::vector<std::pair<ast::Variable*, ast::Decoration*>> outvariables;
+ std::vector<std::pair<const ast::Variable*, ast::Decoration*>> in_variables;
+ std::vector<std::pair<const ast::Variable*, ast::Decoration*>> outvariables;
auto* func_sem = builder_.Sem().Get(func);
auto func_sym = func->symbol();
@@ -1595,7 +1597,7 @@
}
// auto* set = data.second.set;
- auto* type = decl->type()->UnwrapIfNeeded();
+ auto* type = var->Type()->UnwrapIfNeeded();
if (auto* strct = type->As<type::Struct>()) {
out << "ConstantBuffer<" << builder_.Symbols().NameFor(strct->symbol())
<< "> " << builder_.Symbols().NameFor(decl->symbol())
@@ -1638,7 +1640,7 @@
}
auto* binding = data.second.binding;
- auto* ac = decl->type()->As<type::AccessControl>();
+ auto* ac = var->Type()->As<type::AccessControl>();
if (ac == nullptr) {
diagnostics_.add_error("access control type required for storage buffer");
return false;
@@ -1672,10 +1674,10 @@
for (auto& data : in_variables) {
auto* var = data.first;
auto* deco = data.second;
+ auto* type = builder_.Sem().Get(var)->Type();
make_indent(out);
- if (!EmitType(out, var->type(),
- builder_.Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
@@ -1722,10 +1724,10 @@
for (auto& data : outvariables) {
auto* var = data.first;
auto* deco = data.second;
+ auto* type = builder_.Sem().Get(var)->Type();
make_indent(out);
- if (!EmitType(out, var->type(),
- builder_.Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
@@ -1766,7 +1768,7 @@
for (auto* var : func_sem->ReferencedModuleVariables()) {
auto* decl = var->Declaration();
- auto* unwrapped_type = decl->type()->UnwrapAll();
+ auto* unwrapped_type = var->Type()->UnwrapAll();
if (!unwrapped_type->Is<type::Texture>() &&
!unwrapped_type->Is<type::Sampler>()) {
continue; // Not interested in this type
@@ -1776,7 +1778,7 @@
continue; // Global already emitted
}
- if (!EmitType(out, decl->type(), "")) {
+ if (!EmitType(out, var->Type(), "")) {
return false;
}
out << " " << namer_.NameFor(builder_.Symbols().NameFor(decl->symbol()))
@@ -1861,7 +1863,8 @@
// Emit entry point parameters.
for (auto* var : func->params()) {
- if (!var->type()->Is<type::Struct>()) {
+ auto* type = builder_.Sem().Get(var)->Type();
+ if (!type->Is<type::Struct>()) {
TINT_ICE(diagnostics_) << "Unsupported non-struct entry point parameter";
}
@@ -1870,7 +1873,7 @@
}
first = false;
- if (!EmitType(out, var->type(), "")) {
+ if (!EmitType(out, type, "")) {
return false;
}
@@ -2024,7 +2027,7 @@
if (var->constructor() != nullptr) {
out << constructor_out.str();
} else {
- if (!EmitZeroValue(out, var->type())) {
+ if (!EmitZeroValue(out, builder_.Sem().Get(var)->Type())) {
return false;
}
}
@@ -2678,10 +2681,11 @@
if (var->is_const()) {
out << "const ";
}
- if (!EmitType(out, var->type(), builder_.Symbols().NameFor(var->symbol()))) {
+ auto* type = builder_.Sem().Get(var)->Type();
+ if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
- if (!var->type()->Is<type::Array>()) {
+ if (!type->Is<type::Array>()) {
out << " " << builder_.Symbols().NameFor(var->symbol());
}
out << constructor_out.str() << ";" << std::endl;
@@ -2713,6 +2717,8 @@
out << pre.str();
}
+ auto* type = builder_.Sem().Get(var)->Type();
+
if (var->HasConstantIdDecoration()) {
auto const_id = var->constant_id();
@@ -2727,8 +2733,7 @@
}
out << "#endif" << std::endl;
out << "static const ";
- if (!EmitType(out, var->type(),
- builder_.Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
out << " " << builder_.Symbols().NameFor(var->symbol())
@@ -2736,11 +2741,10 @@
out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
} else {
out << "static const ";
- if (!EmitType(out, var->type(),
- builder_.Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
- if (!var->type()->Is<type::Array>()) {
+ if (!type->Is<type::Array>()) {
out << " " << builder_.Symbols().NameFor(var->symbol());
}
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 0f65366..f5941f0 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -412,6 +412,10 @@
TEST_F(HlslGeneratorImplTest_Binary, Decl_WithLogical) {
// var a : bool = (b && c) || d;
+ auto* b_decl = Decl(Var("b", ty.bool_(), ast::StorageClass::kFunction));
+ auto* c_decl = Decl(Var("c", ty.bool_(), ast::StorageClass::kFunction));
+ auto* d_decl = Decl(Var("d", ty.bool_(), ast::StorageClass::kFunction));
+
auto* b = Expr("b");
auto* c = Expr("c");
auto* d = Expr("d");
@@ -422,11 +426,12 @@
ast::BinaryOp::kLogicalOr,
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, b, c), d));
- auto* expr = create<ast::VariableDeclStatement>(var);
+ auto* decl = Decl(var);
+ WrapInFunction(b_decl, c_decl, d_decl, Decl(var));
GeneratorImpl& gen = Build();
- ASSERT_TRUE(gen.EmitStatement(out, expr)) << gen.error();
+ ASSERT_TRUE(gen.EmitStatement(out, decl)) << gen.error();
EXPECT_EQ(result(), R"(bool _tint_tmp = b;
if (_tint_tmp) {
_tint_tmp = c;
diff --git a/src/writer/hlsl/generator_impl_module_constant_test.cc b/src/writer/hlsl/generator_impl_module_constant_test.cc
index b3a197f..405b916 100644
--- a/src/writer/hlsl/generator_impl_module_constant_test.cc
+++ b/src/writer/hlsl/generator_impl_module_constant_test.cc
@@ -24,6 +24,7 @@
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_ModuleConstant) {
auto* var = Const("pos", ty.array<f32, 3>(), array<f32, 3>(1.f, 2.f, 3.f));
+ WrapInFunction(Decl(var));
GeneratorImpl& gen = Build();
@@ -36,6 +37,7 @@
ast::DecorationList{
create<ast::ConstantIdDecoration>(23),
});
+ WrapInFunction(Decl(var));
GeneratorImpl& gen = Build();
@@ -53,6 +55,7 @@
ast::DecorationList{
create<ast::ConstantIdDecoration>(23),
});
+ WrapInFunction(Decl(var));
GeneratorImpl& gen = Build();
diff --git a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
index 8ffce85..495763d 100644
--- a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
+++ b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
@@ -39,6 +39,7 @@
auto* var = Const("a", ty.f32());
auto* stmt = create<ast::VariableDeclStatement>(var);
+ WrapInFunction(stmt);
GeneratorImpl& gen = Build();
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index b3142a2..cf7e396 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -964,8 +964,8 @@
bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto* func_sem = program_->Sem().Get(func);
- std::vector<std::pair<ast::Variable*, uint32_t>> in_locations;
- std::vector<std::pair<ast::Variable*, ast::Decoration*>> out_variables;
+ std::vector<std::pair<const ast::Variable*, uint32_t>> in_locations;
+ std::vector<std::pair<const ast::Variable*, ast::Decoration*>> out_variables;
for (auto data : func_sem->ReferencedLocationVariables()) {
auto* var = data.first;
@@ -1003,7 +1003,8 @@
uint32_t loc = data.second;
make_indent();
- if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(program_->Sem().Get(var)->Type(),
+ program_->Symbols().NameFor(var->symbol()))) {
return false;
}
@@ -1039,7 +1040,8 @@
auto* deco = data.second;
make_indent();
- if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) {
+ if (!EmitType(program_->Sem().Get(var)->Type(),
+ program_->Symbols().NameFor(var->symbol()))) {
return false;
}
@@ -1252,7 +1254,7 @@
first = false;
out_ << "thread ";
- if (!EmitType(var->Declaration()->type(), "")) {
+ if (!EmitType(var->Type(), "")) {
return false;
}
out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol());
@@ -1267,7 +1269,7 @@
out_ << "constant ";
// TODO(dsinclair): Can arrays be uniform? If so, fix this ...
- if (!EmitType(var->Declaration()->type(), "")) {
+ if (!EmitType(var->Type(), "")) {
return false;
}
out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol());
@@ -1280,7 +1282,7 @@
}
first = false;
- auto* ac = var->Declaration()->type()->As<type::AccessControl>();
+ auto* ac = var->Type()->As<type::AccessControl>();
if (ac == nullptr) {
diagnostics_.add_error(
"invalid type for storage buffer, expected access control");
@@ -1303,11 +1305,13 @@
}
first = false;
- if (!EmitType(v->type(), program_->Symbols().NameFor(v->symbol()))) {
+ auto* type = program_->Sem().Get(v)->Type();
+
+ if (!EmitType(type, program_->Symbols().NameFor(v->symbol()))) {
return false;
}
// Array name is output as part of the type
- if (!v->type()->Is<type::Array>()) {
+ if (!type->Is<type::Array>()) {
out_ << " " << program_->Symbols().NameFor(v->symbol());
}
}
@@ -1394,13 +1398,15 @@
}
first = false;
- if (!EmitType(var->type(), "")) {
+ auto* type = program_->Sem().Get(var)->Type();
+
+ if (!EmitType(type, "")) {
return false;
}
out_ << " " << program_->Symbols().NameFor(var->symbol());
- if (var->type()->Is<type::Struct>()) {
+ if (type->Is<type::Struct>()) {
out_ << " [[stage_in]]";
} else {
auto& decos = var->decorations();
@@ -1440,7 +1446,7 @@
auto* builtin = data.second;
- if (!EmitType(var->Declaration()->type(), "")) {
+ if (!EmitType(var->Type(), "")) {
return false;
}
@@ -1475,7 +1481,7 @@
out_ << "constant ";
// TODO(dsinclair): Can you have a uniform array? If so, this needs to be
// updated to handle arrays property.
- if (!EmitType(var->Declaration()->type(), "")) {
+ if (!EmitType(var->Type(), "")) {
return false;
}
out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol())
@@ -1495,7 +1501,7 @@
auto* binding = data.second.binding;
// auto* set = data.second.set;
- auto* ac = var->Declaration()->type()->As<type::AccessControl>();
+ auto* ac = var->Type()->As<type::AccessControl>();
if (ac == nullptr) {
diagnostics_.add_error(
"invalid type for storage buffer, expected access control");
@@ -1640,7 +1646,7 @@
return false;
}
} else {
- if (!EmitZeroValue(var->type())) {
+ if (!EmitZeroValue(program_->Sem().Get(var)->Type())) {
return false;
}
}
@@ -2156,10 +2162,10 @@
if (decl->is_const()) {
out_ << "const ";
}
- if (!EmitType(decl->type(), program_->Symbols().NameFor(decl->symbol()))) {
+ if (!EmitType(var->Type(), program_->Symbols().NameFor(decl->symbol()))) {
return false;
}
- if (!decl->type()->Is<type::Array>()) {
+ if (!var->Type()->Is<type::Array>()) {
out_ << " " << program_->Symbols().NameFor(decl->symbol());
}
@@ -2173,7 +2179,7 @@
var->StorageClass() == ast::StorageClass::kFunction ||
var->StorageClass() == ast::StorageClass::kNone ||
var->StorageClass() == ast::StorageClass::kOutput) {
- if (!EmitZeroValue(decl->type())) {
+ if (!EmitZeroValue(var->Type())) {
return false;
}
}
@@ -2198,10 +2204,11 @@
}
out_ << "constant ";
- if (!EmitType(var->type(), program_->Symbols().NameFor(var->symbol()))) {
+ auto* type = program_->Sem().Get(var)->Type();
+ if (!EmitType(type, program_->Symbols().NameFor(var->symbol()))) {
return false;
}
- if (!var->type()->Is<type::Array>()) {
+ if (!type->Is<type::Array>()) {
out_ << " " << program_->Symbols().NameFor(var->symbol());
}
diff --git a/src/writer/msl/generator_impl_module_constant_test.cc b/src/writer/msl/generator_impl_module_constant_test.cc
index b17ce9d..c793214 100644
--- a/src/writer/msl/generator_impl_module_constant_test.cc
+++ b/src/writer/msl/generator_impl_module_constant_test.cc
@@ -24,6 +24,7 @@
TEST_F(MslGeneratorImplTest, Emit_ModuleConstant) {
auto* var = Const("pos", ty.array<f32, 3>(), array<f32, 3>(1.f, 2.f, 3.f));
+ WrapInFunction(Decl(var));
GeneratorImpl& gen = Build();
@@ -36,6 +37,7 @@
ast::DecorationList{
create<ast::ConstantIdDecoration>(23),
});
+ WrapInFunction(Decl(var));
GeneratorImpl& gen = Build();
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 440699a..96b5f65 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -538,7 +538,8 @@
auto param_op = result_op();
auto param_id = param_op.to_i();
- auto param_type_id = GenerateTypeIfNeeded(param->type());
+ auto param_type_id =
+ GenerateTypeIfNeeded(builder_.Sem().Get(param)->Type());
if (param_type_id == 0) {
return false;
}
@@ -592,7 +593,8 @@
OperandList ops = {func_op, Operand::Int(ret_id)};
for (auto* param : func->params()) {
- auto param_type_id = GenerateTypeIfNeeded(param->type());
+ auto param_type_id =
+ GenerateTypeIfNeeded(builder_.Sem().Get(param)->Type());
if (param_type_id == 0) {
return 0;
}
@@ -629,7 +631,8 @@
auto result = result_op();
auto var_id = result.to_i();
auto sc = ast::StorageClass::kFunction;
- type::Pointer pt(var->type(), sc);
+ auto* type = builder_.Sem().Get(var)->Type();
+ type::Pointer pt(type, sc);
auto type_id = GenerateTypeIfNeeded(&pt);
if (type_id == 0) {
return false;
@@ -641,7 +644,7 @@
// TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad.
- auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded());
+ auto null_id = GenerateConstantNullIfNeeded(type->UnwrapPtrIfNeeded());
if (null_id == 0) {
return 0;
}
@@ -704,7 +707,7 @@
? ast::StorageClass::kPrivate
: sem->StorageClass();
- type::Pointer pt(var->type(), sc);
+ type::Pointer pt(sem->Type(), sc);
auto type_id = GenerateTypeIfNeeded(&pt);
if (type_id == 0) {
return false;
@@ -719,11 +722,11 @@
// Unwrap after emitting the access control as unwrap all removes access
// control types.
- auto* type = var->type()->UnwrapAll();
+ auto* type_no_ac = sem->Type()->UnwrapAll();
if (var->has_constructor()) {
ops.push_back(Operand::Int(init_id));
- } else if (type->Is<type::Texture>()) {
- if (auto* ac = var->type()->As<type::AccessControl>()) {
+ } else if (type_no_ac->Is<type::Texture>()) {
+ if (auto* ac = sem->Type()->As<type::AccessControl>()) {
switch (ac->access_control()) {
case ast::AccessControl::kWriteOnly:
push_annot(
@@ -739,7 +742,7 @@
break;
}
}
- } else if (!type->Is<type::Sampler>()) {
+ } else if (!type_no_ac->Is<type::Sampler>()) {
// Certain cases require us to generate a constructor value.
//
// 1- ConstantId's must be attached to the OpConstant, if we have a
@@ -748,17 +751,17 @@
// 2- If we don't have a constructor and we're an Output or Private variable
// then WGSL requires an initializer.
if (var->HasConstantIdDecoration()) {
- if (type->Is<type::F32>()) {
- ast::FloatLiteral l(Source{}, type, 0.0f);
+ if (type_no_ac->Is<type::F32>()) {
+ ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
init_id = GenerateLiteralIfNeeded(var, &l);
- } else if (type->Is<type::U32>()) {
- ast::UintLiteral l(Source{}, type, 0);
+ } else if (type_no_ac->Is<type::U32>()) {
+ ast::UintLiteral l(Source{}, type_no_ac, 0);
init_id = GenerateLiteralIfNeeded(var, &l);
- } else if (type->Is<type::I32>()) {
- ast::SintLiteral l(Source{}, type, 0);
+ } else if (type_no_ac->Is<type::I32>()) {
+ ast::SintLiteral l(Source{}, type_no_ac, 0);
init_id = GenerateLiteralIfNeeded(var, &l);
- } else if (type->Is<type::Bool>()) {
- ast::BoolLiteral l(Source{}, type, false);
+ } else if (type_no_ac->Is<type::Bool>()) {
+ ast::BoolLiteral l(Source{}, type_no_ac, false);
init_id = GenerateLiteralIfNeeded(var, &l);
} else {
error_ = "invalid type for constant_id, must be scalar";
@@ -771,7 +774,7 @@
} else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
sem->StorageClass() == ast::StorageClass::kNone ||
sem->StorageClass() == ast::StorageClass::kOutput) {
- init_id = GenerateConstantNullIfNeeded(type);
+ init_id = GenerateConstantNullIfNeeded(type_no_ac);
if (init_id == 0) {
return 0;
}
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 6539d93..f894c4e 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -321,7 +321,7 @@
out_ << program_->Symbols().NameFor(v->symbol()) << " : ";
- if (!EmitType(v->type())) {
+ if (!EmitType(program_->Sem().Get(v)->Type())) {
return false;
}
}
@@ -578,13 +578,13 @@
out_ << "var";
if (sem->StorageClass() != ast::StorageClass::kNone &&
sem->StorageClass() != ast::StorageClass::kFunction &&
- !var->type()->UnwrapAll()->is_handle()) {
+ !sem->Type()->UnwrapAll()->is_handle()) {
out_ << "<" << sem->StorageClass() << ">";
}
}
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : ";
- if (!EmitType(var->type())) {
+ if (!EmitType(sem->Type())) {
return false;
}
diff --git a/src/writer/wgsl/generator_impl_variable_test.cc b/src/writer/wgsl/generator_impl_variable_test.cc
index b80d4bf..87660fd 100644
--- a/src/writer/wgsl/generator_impl_variable_test.cc
+++ b/src/writer/wgsl/generator_impl_variable_test.cc
@@ -75,23 +75,24 @@
}
TEST_F(WgslGeneratorImplTest, EmitVariable_Constructor) {
- auto* v =
- Global("a", ty.f32(), ast::StorageClass::kNone, Expr("initializer"));
+ auto* v = Global("a", ty.f32(), ast::StorageClass::kNone, Expr(1.0f));
+ WrapInFunction(Decl(v));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitVariable(v)) << gen.error();
- EXPECT_EQ(gen.result(), R"(var a : f32 = initializer;
+ EXPECT_EQ(gen.result(), R"(var a : f32 = 1.0;
)");
}
TEST_F(WgslGeneratorImplTest, EmitVariable_Const) {
- auto* v = Const("a", ty.f32(), Expr("initializer"));
+ auto* v = Const("a", ty.f32(), Expr(1.0f));
+ WrapInFunction(Decl(v));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitVariable(v)) << gen.error();
- EXPECT_EQ(gen.result(), R"(const a : f32 = initializer;
+ EXPECT_EQ(gen.result(), R"(const a : f32 = 1.0;
)");
}