Resolver: compute canonical types and store them as semantic::Variable::Type
We define the canonical type as a type stripped of all aliases. For
example, Canonical(alias<alias<vec3<alias<f32>>>>) is vec3<f32>. This
change adds Resolver::Canonical(Type*) which caches and returns the
resulting canonical type. We use this throughout the Resolver instead of
UnwrapAliasIfNeeded(), and we store the result in semantic::Variable,
returned from it's Type() member function.
Also:
* Wrote unit tests for Resolver::Canonical()
* Added semantic::Variable::DeclaredType() as a convenience to
retrieve the AST variable's type.
* Updated post-resolve code (transforms) to make use of Type and
DeclaredType appropriately, removing unnecessary calls to
UnwrapAliasIfNeeded.
* Added IntrinsicTableTest.MatchWithNestedAliasUnwrapping to ensure we
don't need to pass canonical parameter types for instrinsic table
lookups.
* ProgramBuilder: added vecN and matMxN overloads that take a Type* arg
to create them with alias types.
Bug: tint:705
Change-Id: I58a3b62538356b8dad2b1161a19b38bcefdd5d62
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47360
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/intrinsic_table_test.cc b/src/intrinsic_table_test.cc
index 1cf1868..350eec7 100644
--- a/src/intrinsic_table_test.cc
+++ b/src/intrinsic_table_test.cc
@@ -357,6 +357,23 @@
EXPECT_THAT(result.intrinsic->Parameters(), ElementsAre(Parameter{ty.f32()}));
}
+TEST_F(IntrinsicTableTest, MatchWithNestedAliasUnwrapping) {
+ auto* alias_a = ty.alias("alias_a", ty.bool_());
+ auto* alias_b = ty.alias("alias_b", alias_a);
+ auto* alias_c = ty.alias("alias_c", alias_b);
+ auto* vec4_of_c = ty.vec4(alias_c);
+ auto* alias_d = ty.alias("alias_d", vec4_of_c);
+ auto* alias_e = ty.alias("alias_e", alias_d);
+
+ auto result = table->Lookup(*this, IntrinsicType::kAll, {alias_e}, Source{});
+ ASSERT_NE(result.intrinsic, nullptr);
+ ASSERT_EQ(result.diagnostics.str(), "");
+ EXPECT_THAT(result.intrinsic->Type(), IntrinsicType::kAll);
+ EXPECT_THAT(result.intrinsic->ReturnType(), ty.bool_());
+ EXPECT_THAT(result.intrinsic->Parameters(),
+ ElementsAre(Parameter{ty.vec4<bool>()}));
+}
+
TEST_F(IntrinsicTableTest, MatchOpenType) {
auto result = table->Lookup(*this, IntrinsicType::kClamp,
{ty.f32(), ty.f32(), ty.f32()}, Source{});
diff --git a/src/program_builder.h b/src/program_builder.h
index 5fddb56..fb261cc 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -338,76 +338,148 @@
/// @returns a void type
type::Void* void_() const { return builder->create<type::Void>(); }
+ /// @param type vector subtype
+ /// @return the tint AST type for a 2-element vector of `type`.
+ type::Vector* vec2(type::Type* type) const {
+ return builder->create<type::Vector>(type, 2u);
+ }
+
+ /// @param type vector subtype
+ /// @return the tint AST type for a 3-element vector of `type`.
+ type::Vector* vec3(type::Type* type) const {
+ return builder->create<type::Vector>(type, 3u);
+ }
+
+ /// @param type vector subtype
+ /// @return the tint AST type for a 4-element vector of `type`.
+ type::Type* vec4(type::Type* type) const {
+ return builder->create<type::Vector>(type, 4u);
+ }
+
/// @return the tint AST type for a 2-element vector of the C type `T`.
template <typename T>
type::Vector* vec2() const {
- return builder->create<type::Vector>(Of<T>(), 2);
+ return vec2(Of<T>());
}
/// @return the tint AST type for a 3-element vector of the C type `T`.
template <typename T>
type::Vector* vec3() const {
- return builder->create<type::Vector>(Of<T>(), 3);
+ return vec3(Of<T>());
}
/// @return the tint AST type for a 4-element vector of the C type `T`.
template <typename T>
type::Type* vec4() const {
- return builder->create<type::Vector>(Of<T>(), 4);
+ return vec4(Of<T>());
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 2x3 matrix of `type`.
+ type::Matrix* mat2x2(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 2u, 2u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 2x3 matrix of `type`.
+ type::Matrix* mat2x3(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 3u, 2u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 2x4 matrix of `type`.
+ type::Matrix* mat2x4(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 4u, 2u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 3x2 matrix of `type`.
+ type::Matrix* mat3x2(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 2u, 3u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 3x3 matrix of `type`.
+ type::Matrix* mat3x3(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 3u, 3u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 3x4 matrix of `type`.
+ type::Matrix* mat3x4(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 4u, 3u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 4x2 matrix of `type`.
+ type::Matrix* mat4x2(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 2u, 4u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 4x3 matrix of `type`.
+ type::Matrix* mat4x3(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 3u, 4u);
+ }
+
+ /// @param type matrix subtype
+ /// @return the tint AST type for a 4x4 matrix of `type`.
+ type::Matrix* mat4x4(type::Type* type) const {
+ return builder->create<type::Matrix>(type, 4u, 4u);
}
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x2() const {
- return builder->create<type::Matrix>(Of<T>(), 2, 2);
+ return mat2x2(Of<T>());
}
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x3() const {
- return builder->create<type::Matrix>(Of<T>(), 3, 2);
+ return mat2x3(Of<T>());
}
/// @return the tint AST type for a 2x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x4() const {
- return builder->create<type::Matrix>(Of<T>(), 4, 2);
+ return mat2x4(Of<T>());
}
/// @return the tint AST type for a 3x2 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x2() const {
- return builder->create<type::Matrix>(Of<T>(), 2, 3);
+ return mat3x2(Of<T>());
}
/// @return the tint AST type for a 3x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x3() const {
- return builder->create<type::Matrix>(Of<T>(), 3, 3);
+ return mat3x3(Of<T>());
}
/// @return the tint AST type for a 3x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x4() const {
- return builder->create<type::Matrix>(Of<T>(), 4, 3);
+ return mat3x4(Of<T>());
}
/// @return the tint AST type for a 4x2 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x2() const {
- return builder->create<type::Matrix>(Of<T>(), 2, 4);
+ return mat4x2(Of<T>());
}
/// @return the tint AST type for a 4x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x3() const {
- return builder->create<type::Matrix>(Of<T>(), 3, 4);
+ return mat4x3(Of<T>());
}
/// @return the tint AST type for a 4x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x4() const {
- return builder->create<type::Matrix>(Of<T>(), 4, 4);
+ return mat4x4(Of<T>());
}
/// @param subtype the array element type
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 06461e5..a18b481 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -42,6 +42,7 @@
#include "src/semantic/struct.h"
#include "src/semantic/variable.h"
#include "src/type/access_control_type.h"
+#include "src/utils/get_or_create.h"
#include "src/utils/math.h"
namespace tint {
@@ -430,7 +431,7 @@
}
// Check that we saw a pipeline IO attribute iff we need one.
- if (ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+ if (Canonical(ty)->Is<type::Struct>()) {
if (pipeline_io_attribute) {
diagnostics_.add_error(
"entry point IO attributes must not be used on structure " +
@@ -466,11 +467,11 @@
return false;
}
- if (auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+ if (auto* struct_ty = Canonical(ty)->As<type::Struct>()) {
// Validate the decorations for each struct members, and also check for
// invalid member types.
for (auto* member : struct_ty->impl()->members()) {
- auto* member_ty = member->type()->UnwrapAliasIfNeeded();
+ auto* member_ty = Canonical(member->type());
if (member_ty->Is<type::Struct>()) {
diagnostics_.add_error(
"entry point IO types cannot contain nested structures",
@@ -547,8 +548,7 @@
return false;
}
- if (auto* str =
- param->declared_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+ if (auto* str = param_info->type->As<type::Struct>()) {
auto* info = Structure(str);
if (!info) {
return false;
@@ -572,8 +572,7 @@
}
}
- if (auto* str =
- func->return_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+ if (auto* str = Canonical(func->return_type())->As<type::Struct>()) {
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
func->source())) {
diagnostics_.add_note("while instantiating return type for " +
@@ -1288,15 +1287,16 @@
using Matrix = type::Matrix;
using Vector = type::Vector;
- auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
- auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
+ auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
+ auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
+
+ auto* lhs_type = Canonical(lhs_declared_type);
+ auto* rhs_type = Canonical(rhs_declared_type);
auto* lhs_vec = lhs_type->As<Vector>();
- auto* lhs_vec_elem_type =
- lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
+ auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_type->As<Vector>();
- auto* rhs_vec_elem_type =
- rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
+ auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_vec_elem_types =
lhs_vec_elem_type && rhs_vec_elem_type &&
@@ -1348,11 +1348,9 @@
}
auto* lhs_mat = lhs_type->As<Matrix>();
- auto* lhs_mat_elem_type =
- lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
+ auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
- auto* rhs_mat_elem_type =
- rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
+ auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
@@ -1438,9 +1436,9 @@
diagnostics_.add_error(
"Binary expression operand types are invalid for this operation: " +
- lhs_type->FriendlyName(builder_->Symbols()) + " " +
+ lhs_declared_type->FriendlyName(builder_->Symbols()) + " " +
FriendlyName(expr->op()) + " " +
- rhs_type->FriendlyName(builder_->Symbols()),
+ rhs_declared_type->FriendlyName(builder_->Symbols()),
expr->source());
return false;
}
@@ -1600,7 +1598,7 @@
}
Resolver::VariableInfo* Resolver::CreateVariableInfo(ast::Variable* var) {
- auto* info = variable_infos_.Create(var);
+ auto* info = variable_infos_.Create(var, Canonical(var->declared_type()));
variable_to_info_.emplace(var, info);
return info;
}
@@ -1748,13 +1746,13 @@
/*vec4*/ 16,
};
- ty = ty->UnwrapAliasIfNeeded();
- if (ty->is_scalar()) {
+ auto* cty = Canonical(ty);
+ if (cty->is_scalar()) {
// Note: Also captures booleans, but these are not host-shareable.
align = 4;
size = 4;
return true;
- } else if (auto* vec = ty->As<type::Vector>()) {
+ } else if (auto* vec = cty->As<type::Vector>()) {
if (vec->size() < 2 || vec->size() > 4) {
TINT_UNREACHABLE(diagnostics_)
<< "Invalid vector size: vec" << vec->size();
@@ -1763,7 +1761,7 @@
align = vector_align[vec->size()];
size = vector_size[vec->size()];
return true;
- } else if (auto* mat = ty->As<type::Matrix>()) {
+ } else if (auto* mat = cty->As<type::Matrix>()) {
if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
mat->rows() > 4) {
TINT_UNREACHABLE(diagnostics_)
@@ -1773,15 +1771,15 @@
align = vector_align[mat->rows()];
size = vector_align[mat->rows()] * mat->columns();
return true;
- } else if (auto* s = ty->As<type::Struct>()) {
+ } else if (auto* s = cty->As<type::Struct>()) {
if (auto* si = Structure(s)) {
align = si->align;
size = si->size;
return true;
}
return false;
- } else if (auto* arr = ty->As<type::Array>()) {
- if (auto* sem = Array(arr)) {
+ } else if (cty->Is<type::Array>()) {
+ if (auto* sem = Array(ty->UnwrapAliasIfNeeded()->As<type::Array>())) {
align = sem->Align();
size = sem->Size();
return true;
@@ -2249,9 +2247,37 @@
return vec_type.FriendlyName(builder_->Symbols());
}
-Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
+type::Type* Resolver::Canonical(type::Type* type) {
+ using Type = type::Type;
+ using Alias = type::Alias;
+ using Matrix = type::Matrix;
+ using Vector = type::Vector;
+
+ std::function<Type*(Type*)> make_canonical;
+ make_canonical = [&](Type* t) -> type::Type* {
+ // Unwrap alias sequence
+ Type* ct = t;
+ while (auto* p = ct->As<Alias>()) {
+ ct = p->type();
+ }
+
+ if (auto* v = ct->As<Vector>()) {
+ return builder_->create<Vector>(make_canonical(v->type()), v->size());
+ }
+ if (auto* m = ct->As<Matrix>()) {
+ return builder_->create<Matrix>(make_canonical(m->type()), m->rows(),
+ m->columns());
+ }
+ return ct;
+ };
+
+ return utils::GetOrCreate(type_to_canonical_, type,
+ [&] { return make_canonical(type); });
+}
+
+Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
: declaration(decl),
- type(decl->declared_type()),
+ type(ctype),
storage_class(decl->declared_storage_class()) {}
Resolver::VariableInfo::~VariableInfo() = default;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index e7f1f76..d4f73fc 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -86,11 +86,17 @@
/// structure member or array element of type `lhs`
static bool IsValidAssignment(type::Type* lhs, type::Type* rhs);
+ /// @param type the input type
+ /// @returns the canonical type for `type`; that is, a type with all aliases
+ /// removed. For example, `Canonical(alias<alias<vec3<alias<f32>>>>)` is
+ /// `vec3<f32>`.
+ type::Type* Canonical(type::Type* type);
+
private:
/// Structure holding semantic information about a variable.
/// Used to build the semantic::Variable nodes at the end of resolving.
struct VariableInfo {
- explicit VariableInfo(ast::Variable* decl);
+ VariableInfo(ast::Variable* decl, type::Type* type);
~VariableInfo();
ast::Variable* const declaration;
@@ -306,6 +312,7 @@
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<type::Struct*, StructInfo*> struct_info_;
+ std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
FunctionInfo* current_function_ = nullptr;
semantic::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index cbcb2a6..cd9dd7a 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -119,18 +119,30 @@
return ty.f32();
}
+using create_type_func_ptr =
+ type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
+
template <typename T>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
+template <create_type_func_ptr create_type>
+type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
+ auto* type = create_type(ty);
+ return ty.vec3(type);
+}
+
template <typename T>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
-using create_type_func_ptr =
- type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
+template <create_type_func_ptr create_type>
+type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
+ auto* type = create_type(ty);
+ return ty.mat3x3(type);
+}
template <create_type_func_ptr create_type>
type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index ec2113c..973c718 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
namespace tint {
+namespace resolver {
namespace {
class ResolverTypeValidationTest : public resolver::TestHelper,
@@ -463,5 +464,51 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+namespace GetCanonicalTests {
+struct Params {
+ create_type_func_ptr create_type;
+ create_type_func_ptr create_canonical_type;
+};
+
+static constexpr Params cases[] = {
+ Params{ty_bool_, ty_bool_},
+ Params{ty_alias<ty_bool_>, ty_bool_},
+ Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
+
+ Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
+ Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
+ Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
+
+ Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
+ Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
+ Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
+ Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
+ ty_vec3<ty_f32>},
+
+ Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
+ Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
+ Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
+ Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
+ ty_mat3x3<ty_f32>},
+};
+
+using CanonicalTest = ResolverTestWithParam<Params>;
+TEST_P(CanonicalTest, All) {
+ auto& params = GetParam();
+
+ auto* type = params.create_type(ty);
+ auto* expected_canonical_type = params.create_canonical_type(ty);
+
+ auto* canonical_type = r()->Canonical(type);
+
+ EXPECT_EQ(canonical_type, expected_canonical_type);
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ CanonicalTest,
+ testing::ValuesIn(cases));
+
+} // namespace GetCanonicalTests
+
} // namespace
+} // namespace resolver
} // namespace tint
diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc
index 13b3820..c345154 100644
--- a/src/semantic/sem_variable.cc
+++ b/src/semantic/sem_variable.cc
@@ -15,6 +15,7 @@
#include "src/semantic/variable.h"
#include "src/ast/identifier_expression.h"
+#include "src/ast/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable);
TINT_INSTANTIATE_TYPEINFO(tint::semantic::VariableUser);
@@ -29,6 +30,10 @@
Variable::~Variable() = default;
+type::Type* Variable::DeclaredType() const {
+ return declaration_->declared_type();
+}
+
VariableUser::VariableUser(ast::IdentifierExpression* declaration,
type::Type* type,
Statement* statement,
diff --git a/src/semantic/variable.h b/src/semantic/variable.h
index 7cae490..9873176 100644
--- a/src/semantic/variable.h
+++ b/src/semantic/variable.h
@@ -52,9 +52,12 @@
/// @returns the AST declaration node
const ast::Variable* Declaration() const { return declaration_; }
- /// @returns the type for the variable
+ /// @returns the canonical type for the variable
type::Type* Type() const { return type_; }
+ /// @returns the AST node's type. May be nullptr.
+ type::Type* DeclaredType() const;
+
/// @returns the storage class for the variable
ast::StorageClass StorageClass() const { return storage_class_; }
diff --git a/src/transform/binding_remapper.cc b/src/transform/binding_remapper.cc
index 4c8ed7b..aed0b0a 100644
--- a/src/transform/binding_remapper.cc
+++ b/src/transform/binding_remapper.cc
@@ -65,8 +65,7 @@
auto ac_it = remappings->access_controls.find(from);
if (ac_it != remappings->access_controls.end()) {
ast::AccessControl ac = ac_it->second;
- auto* var_ty = in->Sem().Get(var)->Type();
- auto* ty = var_ty->UnwrapAliasIfNeeded();
+ auto* ty = in->Sem().Get(var)->Type();
type::Type* inner_ty = nullptr;
if (auto* old_ac = ty->As<type::AccessControl>()) {
inner_ty = ctx.Clone(old_ac->type());
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index db1b5a7..72c88b4 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -72,11 +72,11 @@
for (auto* param : func->params()) {
auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
+ auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
ast::Expression* func_const_initializer = nullptr;
- if (auto* struct_ty =
- param_ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
+ if (auto* struct_ty = param_ty->As<type::Struct>()) {
// Pull out all struct members and build initializer list.
ast::ExpressionList init_values;
for (auto* member : struct_ty->impl()->members()) {
@@ -97,7 +97,7 @@
}
func_const_initializer =
- ctx.dst->Construct(ctx.Clone(param_ty), init_values);
+ ctx.dst->Construct(ctx.Clone(param_declared_ty), init_values);
} else {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, param->decorations(), [](const ast::Decoration* deco) {
@@ -105,7 +105,7 @@
ast::LocationDecoration>();
});
new_struct_members.push_back(ctx.dst->Member(
- param_name, ctx.Clone(param_ty), new_decorations));
+ param_name, ctx.Clone(param_declared_ty), new_decorations));
func_const_initializer =
ctx.dst->MemberAccessor(new_struct_param_symbol, param_name);
}
@@ -117,8 +117,8 @@
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter.
- auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
- func_const_initializer);
+ auto* func_const = ctx.dst->Const(
+ param_name, ctx.Clone(param_declared_ty), func_const_initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 6945cbf..6adc931 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -141,7 +141,8 @@
for (auto* param : func->params()) {
Symbol new_var = HoistToInputVariables(
- ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations());
+ ctx, func, ctx.src->Sem().Get(param)->Type(),
+ ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
@@ -153,9 +154,9 @@
if (!func->return_type()->Is<type::Void>()) {
ast::StatementList stores;
auto store_value_symbol = ctx.dst->Symbols().New();
- HoistToOutputVariables(ctx, func, func->return_type(),
- func->return_type_decorations(), {},
- store_value_symbol, stores);
+ HoistToOutputVariables(
+ ctx, func, func->return_type(), func->return_type(),
+ func->return_type_decorations(), {}, store_value_symbol, stores);
// Create a function that writes a return value to all output variables.
auto* store_value =
@@ -251,8 +252,9 @@
CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
+ type::Type* declared_ty,
const ast::DecorationList& decorations) const {
- if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+ if (!ty->Is<type::Struct>()) {
// Base case: create a global variable and return.
ast::DecorationList new_decorations =
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@@ -261,7 +263,7 @@
});
auto global_var_symbol = ctx.dst->Symbols().New();
auto* global_var =
- ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+ ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kInput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
return global_var_symbol;
@@ -269,10 +271,10 @@
// Recurse into struct members and build the initializer list.
ast::ExpressionList init_values;
- auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+ auto* struct_ty = ty->As<type::Struct>();
for (auto* member : struct_ty->impl()->members()) {
- auto member_var =
- HoistToInputVariables(ctx, func, member->type(), member->decorations());
+ auto member_var = HoistToInputVariables(
+ ctx, func, member->type(), member->type(), member->decorations());
init_values.push_back(ctx.dst->Expr(member_var));
}
@@ -283,8 +285,9 @@
}
// Create a function-scope variable for the struct.
- auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
- auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
+ auto* initializer = ctx.dst->Construct(ctx.Clone(declared_ty), init_values);
+ auto* func_var =
+ ctx.dst->Const(func_var_symbol, ctx.Clone(declared_ty), initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_var));
return func_var_symbol;
@@ -293,12 +296,13 @@
void Spirv::HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
+ type::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
ast::StatementList& stores) const {
// Base case.
- if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+ if (!ty->Is<type::Struct>()) {
// Create a global variable.
ast::DecorationList new_decorations =
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@@ -307,7 +311,7 @@
});
auto global_var_symbol = ctx.dst->Symbols().New();
auto* global_var =
- ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+ ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kOutput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
@@ -322,11 +326,12 @@
}
// Recurse into struct members.
- auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+ auto* struct_ty = ty->As<type::Struct>();
for (auto* member : struct_ty->impl()->members()) {
member_accesses.push_back(ctx.Clone(member->symbol()));
- HoistToOutputVariables(ctx, func, member->type(), member->decorations(),
- member_accesses, store_value, stores);
+ HoistToOutputVariables(ctx, func, member->type(), member->type(),
+ member->decorations(), member_accesses, store_value,
+ stores);
member_accesses.pop_back();
}
}
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index 7ac53e1..a80a1dc 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -60,6 +60,7 @@
Symbol HoistToInputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
+ type::Type* declared_ty,
const ast::DecorationList& decorations) const;
/// Recursively create module-scope output variables for `ty` and build a list
@@ -73,6 +74,7 @@
void HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
+ type::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 2c35689..81ae49a 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1894,7 +1894,7 @@
for (auto* var : func_sem->ReferencedModuleVariables()) {
auto* decl = var->Declaration();
- auto* unwrapped_type = var->Type()->UnwrapAll();
+ auto* unwrapped_type = var->DeclaredType()->UnwrapAll();
if (!emitted_globals.emplace(decl->symbol()).second) {
continue; // Global already emitted
}
@@ -1905,7 +1905,7 @@
continue; // Not interested in this type
}
- if (!EmitType(out, var->Type(), var->StorageClass(), "")) {
+ if (!EmitType(out, var->DeclaredType(), var->StorageClass(), "")) {
return false;
}
out << " " << builder_.Symbols().NameFor(decl->symbol());
@@ -1915,9 +1915,7 @@
if (unwrapped_type->Is<type::Texture>()) {
register_space = "t";
if (unwrapped_type->Is<type::StorageTexture>()) {
- if (auto* ac = var->Type()
- ->UnwrapAliasIfNeeded()
- ->As<type::AccessControl>()) {
+ if (auto* ac = var->Type()->As<type::AccessControl>()) {
if (!ac->IsReadOnly()) {
register_space = "u";
}
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 22285c0..78242ec 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -325,7 +325,7 @@
out_ << program_->Symbols().NameFor(v->symbol()) << " : ";
- if (!EmitType(program_->Sem().Get(v)->Type())) {
+ if (!EmitType(program_->Sem().Get(v)->DeclaredType())) {
return false;
}
}
@@ -599,7 +599,7 @@
}
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : ";
- if (!EmitType(sem->Type())) {
+ if (!EmitType(sem->DeclaredType())) {
return false;
}