Migrate to using semantic::Expression
Remove the mutable `result_type` from the ast::Expression.
Replace this with the use of semantic::Expression.
Bug: tint:390
Change-Id: I1f0eaf0dce8fde46fefe50bf2c5fe5b2e4d2d2df
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/39007
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/expression.cc b/src/ast/expression.cc
index 99eb1f9..27745cf 100644
--- a/src/ast/expression.cc
+++ b/src/ast/expression.cc
@@ -14,6 +14,9 @@
#include "src/ast/expression.h"
+#include "src/semantic/expression.h"
+#include "src/semantic/info.h"
+
TINT_INSTANTIATE_CLASS_ID(tint::ast::Expression);
namespace tint {
@@ -25,9 +28,9 @@
Expression::~Expression() = default;
-void Expression::set_result_type(type::Type* type) {
- // The expression result should never be an alias or access-controlled type
- result_type_ = type->UnwrapIfNeeded();
+std::string Expression::result_type_str(const semantic::Info& sem) const {
+ auto* sem_expr = sem.Get(this);
+ return sem_expr ? sem_expr->Type()->type_name() : "not set";
}
} // namespace ast
diff --git a/src/ast/expression.h b/src/ast/expression.h
index f46d461..62a9a72 100644
--- a/src/ast/expression.h
+++ b/src/ast/expression.h
@@ -30,18 +30,6 @@
public:
~Expression() override;
- /// Sets the resulting type of this expression
- /// @param type the result type to set
- void set_result_type(type::Type* type);
- /// @returns the resulting type from this expression
- type::Type* result_type() const { return result_type_; }
-
- /// @returns a string representation of the result type or 'not set' if no
- /// result type present
- std::string result_type_str(const semantic::Info&) const {
- return result_type_ ? result_type_->type_name() : "not set";
- }
-
protected:
/// Constructor
/// @param source the source of the expression
@@ -49,10 +37,13 @@
/// Move constructor
Expression(Expression&&);
+ /// @param sem the semantic info for the program
+ /// @returns a string representation of the result type or 'not set' if no
+ /// result type present
+ std::string result_type_str(const semantic::Info& sem) const;
+
private:
Expression(const Expression&) = delete;
-
- type::Type* result_type_ = nullptr; // Semantic info
};
/// A list of expressions
diff --git a/src/program.cc b/src/program.cc
index d5ff433..8b7b4bc 100644
--- a/src/program.cc
+++ b/src/program.cc
@@ -21,6 +21,7 @@
#include "src/clone_context.h"
#include "src/demangler.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type_determiner.h"
namespace tint {
@@ -102,6 +103,11 @@
return is_valid_;
}
+type::Type* Program::TypeOf(ast::Expression* expr) const {
+ auto* sem = Sem().Get(expr);
+ return sem ? sem->Type() : nullptr;
+}
+
std::string Program::to_str(bool demangle) const {
AssertNotMoved();
auto str = ast_->to_str(Sem());
diff --git a/src/program.h b/src/program.h
index 7388693..8954b7f 100644
--- a/src/program.h
+++ b/src/program.h
@@ -115,6 +115,12 @@
/// information
bool IsValid() const;
+ /// Helper for returning the resolved semantic type of the expression `expr`.
+ /// @param expr the AST expression
+ /// @return the resolved semantic type for the expression, or nullptr if the
+ /// expression has no resolved type.
+ type::Type* TypeOf(ast::Expression* expr) const;
+
/// @param demangle whether to automatically demangle the symbols in the
/// returned string
/// @returns a string describing this program.
diff --git a/src/program_builder.cc b/src/program_builder.cc
index ddd146a..8286958 100644
--- a/src/program_builder.cc
+++ b/src/program_builder.cc
@@ -20,6 +20,7 @@
#include "src/clone_context.h"
#include "src/demangler.h"
+#include "src/semantic/expression.h"
#include "src/type/struct_type.h"
namespace tint {
@@ -82,6 +83,11 @@
assert(!moved_);
}
+type::Type* ProgramBuilder::TypeOf(ast::Expression* expr) const {
+ auto* sem = Sem().Get(expr);
+ return sem ? sem->Type() : nullptr;
+}
+
ProgramBuilder::TypesBuilder::TypesBuilder(ProgramBuilder* pb) : builder(pb) {}
ast::Variable* ProgramBuilder::Var(const std::string& name,
diff --git a/src/program_builder.h b/src/program_builder.h
index 2fd6377..b683ac9 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -937,6 +937,15 @@
source_ = Source(loc);
}
+ /// Helper for returning the resolved semantic type of the expression `expr`.
+ /// @note As the TypeDeterminator is run when the Program is built, this will
+ /// only be useful for the TypeDeterminer itself and tests that use their own
+ /// TypeDeterminer.
+ /// @param expr the AST expression
+ /// @return the resolved semantic type for the expression, or nullptr if the
+ /// expression has no resolved type.
+ type::Type* TypeOf(ast::Expression* expr) const;
+
/// The builder types
TypesBuilder ty;
diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc
index dd293ff..b17a181 100644
--- a/src/transform/bound_array_accessors.cc
+++ b/src/transform/bound_array_accessors.cc
@@ -44,6 +44,7 @@
#include "src/ast/variable_decl_statement.h"
#include "src/clone_context.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type/array_type.h"
#include "src/type/matrix_type.h"
#include "src/type/u32_type.h"
@@ -70,7 +71,7 @@
ast::ArrayAccessorExpression* expr,
CloneContext* ctx,
diag::List* diags) {
- auto* ret_type = expr->array()->result_type()->UnwrapAll();
+ auto* ret_type = ctx->src->Sem().Get(expr->array())->Type()->UnwrapAll();
if (!ret_type->Is<type::Array>() && !ret_type->Is<type::Matrix>() &&
!ret_type->Is<type::Vector>()) {
return nullptr;
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index bf5e601..4256651 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -43,6 +43,7 @@
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type/array_type.h"
#include "src/type/bool_type.h"
#include "src/type/depth_texture_type.h"
@@ -308,6 +309,10 @@
return true;
}
+ if (TypeOf(expr)) {
+ return true; // Already resolved
+ }
+
if (auto* a = expr->As<ast::ArrayAccessorExpression>()) {
return DetermineArrayAccessor(a);
}
@@ -346,7 +351,7 @@
return false;
}
- auto* res = expr->array()->result_type();
+ auto* res = TypeOf(expr->array());
auto* parent_type = res->UnwrapAll();
type::Type* ret = nullptr;
if (auto* arr = parent_type->As<type::Array>()) {
@@ -373,7 +378,7 @@
ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction);
}
}
- expr->set_result_type(ret);
+ SetType(expr, ret);
return true;
}
@@ -382,7 +387,7 @@
if (!DetermineResultType(expr->expr())) {
return false;
}
- expr->set_result_type(expr->type());
+ SetType(expr, expr->type());
return true;
}
@@ -420,12 +425,6 @@
set_referenced_from_function_if_needed(var, false);
}
}
-
- // An identifier with a single name is a function call, not an import
- // lookup which we can handle with the regular identifier lookup.
- if (!DetermineResultType(ident)) {
- return false;
- }
}
} else {
if (!DetermineResultType(expr->func())) {
@@ -433,7 +432,9 @@
}
}
- if (!expr->func()->result_type()) {
+ if (auto* type = TypeOf(expr->func())) {
+ SetType(expr, type);
+ } else {
auto func_sym = expr->func()->As<ast::IdentifierExpression>()->symbol();
set_error(expr->source(),
"v-0005: function must be declared before use: '" +
@@ -441,7 +442,6 @@
return false;
}
- expr->set_result_type(expr->func()->result_type());
return true;
}
@@ -530,17 +530,17 @@
}
// The result type must be the same as the type of the parameter.
- auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
- expr->func()->set_result_type(param_type);
+ auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
+ SetType(expr->func(), param_type);
return true;
}
if (ident->intrinsic() == ast::Intrinsic::kAny ||
ident->intrinsic() == ast::Intrinsic::kAll) {
- expr->func()->set_result_type(builder_->create<type::Bool>());
+ SetType(expr->func(), builder_->create<type::Bool>());
return true;
}
if (ident->intrinsic() == ast::Intrinsic::kArrayLength) {
- expr->func()->set_result_type(builder_->create<type::U32>());
+ SetType(expr->func(), builder_->create<type::U32>());
return true;
}
if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) {
@@ -553,12 +553,12 @@
auto* bool_type = builder_->create<type::Bool>();
- auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
+ auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
if (auto* vec = param_type->As<type::Vector>()) {
- expr->func()->set_result_type(
- builder_->create<type::Vector>(bool_type, vec->size()));
+ SetType(expr->func(),
+ builder_->create<type::Vector>(bool_type, vec->size()));
} else {
- expr->func()->set_result_type(bool_type);
+ SetType(expr->func(), bool_type);
}
return true;
}
@@ -566,14 +566,14 @@
ast::intrinsic::TextureSignature::Parameters param;
auto* texture_param = expr->params()[0];
- if (!texture_param->result_type()->UnwrapAll()->Is<type::Texture>()) {
+ if (!TypeOf(texture_param)->UnwrapAll()->Is<type::Texture>()) {
set_error(expr->source(),
"invalid first argument for " +
builder_->Symbols().NameFor(ident->symbol()));
return false;
}
type::Texture* texture =
- texture_param->result_type()->UnwrapAll()->As<type::Texture>();
+ TypeOf(texture_param)->UnwrapAll()->As<type::Texture>();
bool is_array = type::IsTextureArray(texture->dim());
bool is_multisampled = texture->Is<type::MultisampledTexture>();
@@ -744,12 +744,12 @@
}
}
}
- expr->func()->set_result_type(return_type);
+ SetType(expr->func(), return_type);
return true;
}
if (ident->intrinsic() == ast::Intrinsic::kDot) {
- expr->func()->set_result_type(builder_->create<type::F32>());
+ SetType(expr->func(), builder_->create<type::F32>());
return true;
}
if (ident->intrinsic() == ast::Intrinsic::kSelect) {
@@ -762,8 +762,8 @@
}
// The result type must be the same as the type of the parameter.
- auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
- expr->func()->set_result_type(param_type);
+ auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
+ SetType(expr->func(), param_type);
return true;
}
@@ -791,8 +791,7 @@
std::vector<type::Type*> result_types;
for (uint32_t i = 0; i < data->param_count; ++i) {
- result_types.push_back(
- expr->params()[i]->result_type()->UnwrapPtrIfNeeded());
+ result_types.push_back(TypeOf(expr->params()[i])->UnwrapPtrIfNeeded());
switch (data->data_type) {
case IntrinsicDataType::kFloatOrIntScalarOrVector:
@@ -869,18 +868,17 @@
// provided.
if (ident->intrinsic() == ast::Intrinsic::kLength ||
ident->intrinsic() == ast::Intrinsic::kDistance) {
- expr->func()->set_result_type(
- result_types[0]->is_float_scalar()
- ? result_types[0]
- : result_types[0]->As<type::Vector>()->type());
+ SetType(expr->func(), result_types[0]->is_float_scalar()
+ ? result_types[0]
+ : result_types[0]->As<type::Vector>()->type());
return true;
}
// The determinant returns the component type of the columns
if (ident->intrinsic() == ast::Intrinsic::kDeterminant) {
- expr->func()->set_result_type(result_types[0]->As<type::Matrix>()->type());
+ SetType(expr->func(), result_types[0]->As<type::Matrix>()->type());
return true;
}
- expr->func()->set_result_type(result_types[0]);
+ SetType(expr->func(), result_types[0]);
return true;
}
@@ -891,10 +889,10 @@
return false;
}
}
- expr->set_result_type(ty->type());
+ SetType(expr, ty->type());
} else {
- expr->set_result_type(
- expr->As<ast::ScalarConstructorExpression>()->literal()->type());
+ SetType(expr,
+ expr->As<ast::ScalarConstructorExpression>()->literal()->type());
}
return true;
}
@@ -906,12 +904,12 @@
// A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type.
if (var->is_const()) {
- expr->set_result_type(var->type());
+ SetType(expr, var->type());
} else if (var->type()->Is<type::Pointer>()) {
- expr->set_result_type(var->type());
+ SetType(expr, var->type());
} else {
- expr->set_result_type(
- builder_->create<type::Pointer>(var->type(), var->storage_class()));
+ SetType(expr, builder_->create<type::Pointer>(var->type(),
+ var->storage_class()));
}
set_referenced_from_function_if_needed(var, true);
@@ -920,7 +918,7 @@
auto iter = symbol_to_function_.find(symbol);
if (iter != symbol_to_function_.end()) {
- expr->set_result_type(iter->second->return_type());
+ SetType(expr, iter->second->return_type());
return true;
}
@@ -1091,7 +1089,7 @@
return false;
}
- auto* res = expr->structure()->result_type();
+ auto* res = TypeOf(expr->structure());
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
type::Type* ret = nullptr;
@@ -1143,7 +1141,7 @@
return false;
}
- expr->set_result_type(ret);
+ SetType(expr, ret);
return true;
}
@@ -1157,7 +1155,7 @@
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
expr->IsDivide() || expr->IsModulo()) {
- expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded());
+ SetType(expr, TypeOf(expr->lhs())->UnwrapPtrIfNeeded());
return true;
}
// Result type is a scalar or vector of boolean type
@@ -1165,18 +1163,17 @@
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = builder_->create<type::Bool>();
- auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
+ auto* param_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
+ type::Type* result_type = bool_type;
if (auto* vec = param_type->As<type::Vector>()) {
- expr->set_result_type(
- builder_->create<type::Vector>(bool_type, vec->size()));
- } else {
- expr->set_result_type(bool_type);
+ result_type = builder_->create<type::Vector>(bool_type, vec->size());
}
+ SetType(expr, result_type);
return true;
}
if (expr->IsMultiply()) {
- auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
- auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
+ auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
+ auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
// Note, the ordering here matters. The later checks depend on the prior
// checks having been done.
@@ -1184,34 +1181,36 @@
auto* rhs_mat = rhs_type->As<type::Matrix>();
auto* lhs_vec = lhs_type->As<type::Vector>();
auto* rhs_vec = rhs_type->As<type::Vector>();
+ type::Type* result_type;
if (lhs_mat && rhs_mat) {
- expr->set_result_type(builder_->create<type::Matrix>(
- lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns()));
+ result_type = builder_->create<type::Matrix>(
+ lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns());
} else if (lhs_mat && rhs_vec) {
- expr->set_result_type(
- builder_->create<type::Vector>(lhs_mat->type(), lhs_mat->rows()));
+ result_type =
+ builder_->create<type::Vector>(lhs_mat->type(), lhs_mat->rows());
} else if (lhs_vec && rhs_mat) {
- expr->set_result_type(
- builder_->create<type::Vector>(rhs_mat->type(), rhs_mat->columns()));
+ result_type =
+ builder_->create<type::Vector>(rhs_mat->type(), rhs_mat->columns());
} else if (lhs_mat) {
// matrix * scalar
- expr->set_result_type(lhs_type);
+ result_type = lhs_type;
} else if (rhs_mat) {
// scalar * matrix
- expr->set_result_type(rhs_type);
+ result_type = rhs_type;
} else if (lhs_vec && rhs_vec) {
- expr->set_result_type(lhs_type);
+ result_type = lhs_type;
} else if (lhs_vec) {
// Vector * scalar
- expr->set_result_type(lhs_type);
+ result_type = lhs_type;
} else if (rhs_vec) {
// Scalar * vector
- expr->set_result_type(rhs_type);
+ result_type = rhs_type;
} else {
// Scalar * Scalar
- expr->set_result_type(lhs_type);
+ result_type = lhs_type;
}
+ SetType(expr, result_type);
return true;
}
@@ -1224,7 +1223,9 @@
if (!DetermineResultType(expr->expr())) {
return false;
}
- expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded());
+
+ auto* result_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded();
+ SetType(expr, result_type);
return true;
}
@@ -1288,4 +1289,9 @@
return false;
}
+void TypeDeterminer::SetType(ast::Expression* expr, type::Type* type) const {
+ return builder_->Sem().Add(expr,
+ builder_->create<semantic::Expression>(type));
+}
+
} // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index c0532a0..663b588 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -21,6 +21,7 @@
#include "src/ast/module.h"
#include "src/diagnostic/diagnostic.h"
+#include "src/program_builder.h"
#include "src/scope_stack.h"
#include "src/type/storage_texture_type.h"
@@ -137,6 +138,18 @@
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
bool DetermineUnaryOp(ast::UnaryOpExpression* expr);
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ type::Type* TypeOf(ast::Expression* expr) const {
+ return builder_->TypeOf(expr);
+ }
+
+ /// Creates a semantic::Expression node with the resolved type `type`, and
+ /// assigns this semantic node to the expression `expr`.
+ /// @param expr the expression
+ /// @param type the resolved type
+ void SetType(ast::Expression* expr, type::Type* type) const;
+
ProgramBuilder* builder_;
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index b9548a9..29daee7 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -51,6 +51,7 @@
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
#include "src/type/bool_type.h"
@@ -132,11 +133,11 @@
auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
EXPECT_TRUE(td()->DetermineResultType(assign));
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Case) {
@@ -151,10 +152,10 @@
auto* cse = create<ast::CaseStatement>(lit, body);
EXPECT_TRUE(td()->DetermineResultType(cse));
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Block) {
@@ -166,10 +167,10 @@
});
EXPECT_TRUE(td()->DetermineResultType(block));
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Else) {
@@ -182,12 +183,12 @@
auto* stmt = create<ast::ElseStatement>(Expr(3), body);
EXPECT_TRUE(td()->DetermineResultType(stmt));
- ASSERT_NE(stmt->condition()->result_type(), nullptr);
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
- EXPECT_TRUE(stmt->condition()->result_type()->Is<type::I32>());
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(stmt->condition()), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_If) {
@@ -210,16 +211,16 @@
ast::ElseStatementList{else_stmt});
EXPECT_TRUE(td()->DetermineResultType(stmt));
- ASSERT_NE(stmt->condition()->result_type(), nullptr);
- ASSERT_NE(else_lhs->result_type(), nullptr);
- ASSERT_NE(else_rhs->result_type(), nullptr);
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
- EXPECT_TRUE(stmt->condition()->result_type()->Is<type::I32>());
- EXPECT_TRUE(else_lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(else_rhs->result_type()->Is<type::F32>());
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(stmt->condition()), nullptr);
+ ASSERT_NE(TypeOf(else_lhs), nullptr);
+ ASSERT_NE(TypeOf(else_rhs), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
+ EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(else_lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(else_rhs)->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Loop) {
@@ -240,14 +241,14 @@
auto* stmt = create<ast::LoopStatement>(body, continuing);
EXPECT_TRUE(td()->DetermineResultType(stmt));
- ASSERT_NE(body_lhs->result_type(), nullptr);
- ASSERT_NE(body_rhs->result_type(), nullptr);
- ASSERT_NE(continuing_lhs->result_type(), nullptr);
- ASSERT_NE(continuing_rhs->result_type(), nullptr);
- EXPECT_TRUE(body_lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(body_rhs->result_type()->Is<type::F32>());
- EXPECT_TRUE(continuing_lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(continuing_rhs->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(body_lhs), nullptr);
+ ASSERT_NE(TypeOf(body_rhs), nullptr);
+ ASSERT_NE(TypeOf(continuing_lhs), nullptr);
+ ASSERT_NE(TypeOf(continuing_rhs), nullptr);
+ EXPECT_TRUE(TypeOf(body_lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(body_rhs)->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(continuing_lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(continuing_rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Return) {
@@ -256,8 +257,8 @@
auto* ret = create<ast::ReturnStatement>(cond);
EXPECT_TRUE(td()->DetermineResultType(ret));
- ASSERT_NE(cond->result_type(), nullptr);
- EXPECT_TRUE(cond->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(cond), nullptr);
+ EXPECT_TRUE(TypeOf(cond)->Is<type::I32>());
}
TEST_F(TypeDeterminerTest, Stmt_Return_WithoutValue) {
@@ -281,13 +282,13 @@
auto* stmt = create<ast::SwitchStatement>(Expr(2), cases);
EXPECT_TRUE(td()->DetermineResultType(stmt)) << td()->error();
- ASSERT_NE(stmt->condition()->result_type(), nullptr);
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(stmt->condition()), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
- EXPECT_TRUE(stmt->condition()->result_type()->Is<type::I32>());
- EXPECT_TRUE(lhs->result_type()->Is<type::I32>());
- EXPECT_TRUE(rhs->result_type()->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Call) {
@@ -303,8 +304,8 @@
auto* call = create<ast::CallStatement>(expr);
EXPECT_TRUE(td()->DetermineResultType(call));
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Stmt_Call_undeclared) {
@@ -343,8 +344,8 @@
auto* decl = create<ast::VariableDeclStatement>(var);
EXPECT_TRUE(td()->DetermineResultType(decl));
- ASSERT_NE(init->result_type(), nullptr);
- EXPECT_TRUE(init->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(init), nullptr);
+ EXPECT_TRUE(TypeOf(init)->Is<type::I32>());
}
TEST_F(TypeDeterminerTest, Stmt_VariableDecl_ModuleScope) {
@@ -355,8 +356,8 @@
AST().AddGlobalVariable(var);
EXPECT_TRUE(td()->Determine());
- ASSERT_NE(init->result_type(), nullptr);
- EXPECT_TRUE(init->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(init), nullptr);
+ EXPECT_TRUE(TypeOf(init)->Is<type::I32>());
}
TEST_F(TypeDeterminerTest, Expr_Error_Unknown) {
@@ -375,10 +376,10 @@
auto* acc = IndexAccessor("my_var", idx);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- ASSERT_TRUE(acc->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<type::Pointer>());
- auto* ptr = acc->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(acc)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -391,10 +392,10 @@
auto* acc = IndexAccessor("my_var", 2);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- ASSERT_TRUE(acc->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<type::Pointer>());
- auto* ptr = acc->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(acc)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -406,9 +407,8 @@
auto* acc = IndexAccessor("my_var", 2);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- EXPECT_TRUE(acc->result_type()->Is<type::F32>())
- << acc->result_type()->type_name();
+ ASSERT_NE(TypeOf(acc), nullptr);
+ EXPECT_TRUE(TypeOf(acc)->Is<type::F32>()) << TypeOf(acc)->type_name();
}
TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) {
@@ -419,10 +419,10 @@
auto* acc = IndexAccessor("my_var", 2);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- ASSERT_TRUE(acc->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<type::Pointer>());
- auto* ptr = acc->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(acc)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::Vector>());
EXPECT_EQ(ptr->type()->As<type::Vector>()->size(), 3u);
}
@@ -436,10 +436,10 @@
auto* acc = IndexAccessor(IndexAccessor("my_var", 2), 1);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- ASSERT_TRUE(acc->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<type::Pointer>());
- auto* ptr = acc->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(acc)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -451,10 +451,10 @@
auto* acc = IndexAccessor("my_var", 2);
EXPECT_TRUE(td()->DetermineResultType(acc));
- ASSERT_NE(acc->result_type(), nullptr);
- ASSERT_TRUE(acc->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(acc), nullptr);
+ ASSERT_TRUE(TypeOf(acc)->Is<type::Pointer>());
- auto* ptr = acc->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(acc)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -465,8 +465,8 @@
td()->RegisterVariableForTesting(v);
EXPECT_TRUE(td()->DetermineResultType(bitcast));
- ASSERT_NE(bitcast->result_type(), nullptr);
- EXPECT_TRUE(bitcast->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(bitcast), nullptr);
+ EXPECT_TRUE(TypeOf(bitcast)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Call) {
@@ -480,8 +480,8 @@
auto* call = Call("my_func");
EXPECT_TRUE(td()->DetermineResultType(call));
- ASSERT_NE(call->result_type(), nullptr);
- EXPECT_TRUE(call->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
@@ -497,8 +497,8 @@
auto* call = Call("my_func", param);
EXPECT_TRUE(td()->DetermineResultType(call));
- ASSERT_NE(param->result_type(), nullptr);
- EXPECT_TRUE(param->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(param), nullptr);
+ EXPECT_TRUE(TypeOf(param)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Call_Intrinsic) {
@@ -508,8 +508,8 @@
auto* call = Call("round", 2.4f);
EXPECT_TRUE(td()->DetermineResultType(call));
- ASSERT_NE(call->result_type(), nullptr);
- EXPECT_TRUE(call->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(call), nullptr);
+ EXPECT_TRUE(TypeOf(call)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Cast) {
@@ -519,25 +519,25 @@
td()->RegisterVariableForTesting(v);
EXPECT_TRUE(td()->DetermineResultType(cast));
- ASSERT_NE(cast->result_type(), nullptr);
- EXPECT_TRUE(cast->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(cast), nullptr);
+ EXPECT_TRUE(TypeOf(cast)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) {
auto* s = Expr(1.0f);
EXPECT_TRUE(td()->DetermineResultType(s));
- ASSERT_NE(s->result_type(), nullptr);
- EXPECT_TRUE(s->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(s), nullptr);
+ EXPECT_TRUE(TypeOf(s)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Constructor_Type) {
auto* tc = vec3<f32>(1.0f, 1.0f, 3.0f);
EXPECT_TRUE(td()->DetermineResultType(tc));
- ASSERT_NE(tc->result_type(), nullptr);
- ASSERT_TRUE(tc->result_type()->Is<type::Vector>());
- EXPECT_TRUE(tc->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(tc->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(tc), nullptr);
+ ASSERT_TRUE(TypeOf(tc)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(tc)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(tc)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
@@ -548,10 +548,9 @@
auto* ident = Expr("my_var");
EXPECT_TRUE(td()->DetermineResultType(ident));
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::Pointer>());
- EXPECT_TRUE(
- ident->result_type()->As<type::Pointer>()->type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::Pointer>());
+ EXPECT_TRUE(TypeOf(ident)->As<type::Pointer>()->type()->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalConstant) {
@@ -561,8 +560,8 @@
auto* ident = Expr("my_var");
EXPECT_TRUE(td()->DetermineResultType(ident));
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
@@ -579,8 +578,8 @@
EXPECT_TRUE(td()->DetermineFunction(f));
- ASSERT_NE(my_var->result_type(), nullptr);
- EXPECT_TRUE(my_var->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(my_var), nullptr);
+ EXPECT_TRUE(TypeOf(my_var)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
@@ -596,10 +595,9 @@
EXPECT_TRUE(td()->DetermineFunction(f));
- ASSERT_NE(my_var->result_type(), nullptr);
- EXPECT_TRUE(my_var->result_type()->Is<type::Pointer>());
- EXPECT_TRUE(
- my_var->result_type()->As<type::Pointer>()->type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(my_var), nullptr);
+ EXPECT_TRUE(TypeOf(my_var)->Is<type::Pointer>());
+ EXPECT_TRUE(TypeOf(my_var)->As<type::Pointer>()->type()->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) {
@@ -617,10 +615,9 @@
EXPECT_TRUE(td()->DetermineFunction(f));
- ASSERT_NE(my_var->result_type(), nullptr);
- EXPECT_TRUE(my_var->result_type()->Is<type::Pointer>());
- EXPECT_TRUE(
- my_var->result_type()->As<type::Pointer>()->type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(my_var), nullptr);
+ EXPECT_TRUE(TypeOf(my_var)->Is<type::Pointer>());
+ EXPECT_TRUE(TypeOf(my_var)->As<type::Pointer>()->type()->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
@@ -633,8 +630,8 @@
auto* ident = Expr("my_func");
EXPECT_TRUE(td()->DetermineResultType(ident));
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Expr_Identifier_Unknown) {
@@ -762,10 +759,10 @@
auto* mem = MemberAccessor("my_struct", "second_member");
EXPECT_TRUE(td()->DetermineResultType(mem));
- ASSERT_NE(mem->result_type(), nullptr);
- ASSERT_TRUE(mem->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<type::Pointer>());
- auto* ptr = mem->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -785,10 +782,10 @@
auto* mem = MemberAccessor("my_struct", "second_member");
EXPECT_TRUE(td()->DetermineResultType(mem));
- ASSERT_NE(mem->result_type(), nullptr);
- ASSERT_TRUE(mem->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<type::Pointer>());
- auto* ptr = mem->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -800,10 +797,10 @@
auto* mem = MemberAccessor("my_vec", "xy");
EXPECT_TRUE(td()->DetermineResultType(mem)) << td()->error();
- ASSERT_NE(mem->result_type(), nullptr);
- ASSERT_TRUE(mem->result_type()->Is<type::Vector>());
- EXPECT_TRUE(mem->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(mem->result_type()->As<type::Vector>()->size(), 2u);
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@@ -814,10 +811,10 @@
auto* mem = MemberAccessor("my_vec", "x");
EXPECT_TRUE(td()->DetermineResultType(mem)) << td()->error();
- ASSERT_NE(mem->result_type(), nullptr);
- ASSERT_TRUE(mem->result_type()->Is<type::Pointer>());
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<type::Pointer>());
- auto* ptr = mem->result_type()->As<type::Pointer>();
+ auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>());
}
@@ -866,10 +863,10 @@
"yx");
EXPECT_TRUE(td()->DetermineResultType(mem)) << td()->error();
- ASSERT_NE(mem->result_type(), nullptr);
- ASSERT_TRUE(mem->result_type()->Is<type::Vector>());
- EXPECT_TRUE(mem->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(mem->result_type()->As<type::Vector>()->size(), 2u);
+ ASSERT_NE(TypeOf(mem), nullptr);
+ ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
}
using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam<ast::BinaryOp>;
@@ -885,8 +882,8 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
}
TEST_P(Expr_Binary_BitwiseTest, Vector) {
@@ -901,10 +898,10 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::I32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Binary_BitwiseTest,
@@ -931,8 +928,8 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::Bool>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
TEST_P(Expr_Binary_LogicalTest, Vector) {
@@ -947,11 +944,10 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::Bool>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Binary_LogicalTest,
@@ -971,8 +967,8 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::Bool>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
TEST_P(Expr_Binary_CompareTest, Vector) {
@@ -987,11 +983,10 @@
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::Bool>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Binary_CompareTest,
@@ -1012,8 +1007,8 @@
auto* expr = Mul("val", "val");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Scalar) {
@@ -1027,10 +1022,10 @@
auto* expr = Mul("vector", "scalar");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Scalar_Vector) {
@@ -1044,10 +1039,10 @@
auto* expr = Mul("scalar", "vector");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Vector) {
@@ -1059,10 +1054,10 @@
auto* expr = Mul("vector", "vector");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Scalar) {
@@ -1076,10 +1071,10 @@
auto* expr = Mul("matrix", "scalar");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Matrix>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
- auto* mat = expr->result_type()->As<type::Matrix>();
+ auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 3u);
EXPECT_EQ(mat->columns(), 2u);
@@ -1096,10 +1091,10 @@
auto* expr = Mul("scalar", "matrix");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Matrix>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
- auto* mat = expr->result_type()->As<type::Matrix>();
+ auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 3u);
EXPECT_EQ(mat->columns(), 2u);
@@ -1116,10 +1111,10 @@
auto* expr = Mul("matrix", "vector");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Matrix) {
@@ -1133,10 +1128,10 @@
auto* expr = Mul("vector", "matrix");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 2u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 2u);
}
TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Matrix) {
@@ -1150,10 +1145,10 @@
auto* expr = Mul("mat3x4", "mat4x3");
ASSERT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Matrix>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
- auto* mat = expr->result_type()->As<type::Matrix>();
+ auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 4u);
EXPECT_EQ(mat->columns(), 4u);
@@ -1172,8 +1167,8 @@
auto* expr = Call(name, "ident");
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::F32>());
}
TEST_P(IntrinsicDerivativeTest, Vector) {
@@ -1188,10 +1183,10 @@
auto* expr = Call(name, "ident");
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 4u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 4u);
}
TEST_P(IntrinsicDerivativeTest, MissingParam) {
@@ -1244,8 +1239,8 @@
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::Bool>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Intrinsic,
@@ -1265,11 +1260,10 @@
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::Bool>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_P(Intrinsic_FloatMethod, Scalar) {
@@ -1284,8 +1278,8 @@
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::Bool>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
TEST_P(Intrinsic_FloatMethod, MissingParam) {
@@ -1406,19 +1400,16 @@
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
if (type == Texture::kF32) {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
} else if (type == Texture::kI32) {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
} else {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::U32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::U32>());
}
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 4u);
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 4u);
}
INSTANTIATE_TEST_SUITE_P(
@@ -1476,19 +1467,16 @@
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- ASSERT_TRUE(expr->result_type()->Is<type::Vector>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
if (type == Texture::kF32) {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
} else if (type == Texture::kI32) {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::I32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
} else {
- EXPECT_TRUE(
- expr->result_type()->As<type::Vector>()->type()->Is<type::U32>());
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::U32>());
}
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 4u);
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 4u);
}
INSTANTIATE_TEST_SUITE_P(
@@ -1509,8 +1497,8 @@
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr));
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Intrinsic_Select) {
@@ -1526,10 +1514,10 @@
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(expr)) << td()->error();
- ASSERT_NE(expr->result_type(), nullptr);
- EXPECT_TRUE(expr->result_type()->Is<type::Vector>());
- EXPECT_EQ(expr->result_type()->As<type::Vector>()->size(), 3u);
- EXPECT_TRUE(expr->result_type()->As<type::Vector>()->type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(expr), nullptr);
+ EXPECT_TRUE(TypeOf(expr)->Is<type::Vector>());
+ EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
+ EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooFewParams) {
@@ -1572,10 +1560,10 @@
auto* der = create<ast::UnaryOpExpression>(op, Expr("ident"));
EXPECT_TRUE(td()->DetermineResultType(der));
- ASSERT_NE(der->result_type(), nullptr);
- ASSERT_TRUE(der->result_type()->Is<type::Vector>());
- EXPECT_TRUE(der->result_type()->As<type::Vector>()->type()->Is<type::F32>());
- EXPECT_EQ(der->result_type()->As<type::Vector>()->size(), 4u);
+ ASSERT_NE(TypeOf(der), nullptr);
+ ASSERT_TRUE(TypeOf(der)->Is<type::Vector>());
+ EXPECT_TRUE(TypeOf(der)->As<type::Vector>()->type()->Is<type::F32>());
+ EXPECT_EQ(TypeOf(der)->As<type::Vector>()->size(), 4u);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
UnaryOpExpressionTest,
@@ -1731,8 +1719,8 @@
auto* call = Call(ident, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_P(ImportData_SingleParamTest, Vector) {
@@ -1742,9 +1730,9 @@
auto* call = Call(ident, vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_SingleParamTest, Error_Integer) {
@@ -1813,8 +1801,8 @@
auto* call = Call(ident, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Float_Vector) {
@@ -1824,9 +1812,9 @@
auto* call = Call(ident, vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Sint_Scalar) {
@@ -1836,8 +1824,8 @@
auto* call = Call(ident, -1);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::I32>());
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Sint_Vector) {
@@ -1855,9 +1843,9 @@
auto* call = Call(ident, params);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_signed_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Uint_Scalar) {
@@ -1870,8 +1858,8 @@
auto* call = Call(ident, params);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::U32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::U32>());
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Uint_Vector) {
@@ -1881,9 +1869,9 @@
auto* call = Call(ident, vec3<u32>(1u, 1u, 3u));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_unsigned_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_SingleParam_FloatOrInt_Test, Error_Bool) {
@@ -1928,8 +1916,8 @@
auto* call = Call(ident, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_F(TypeDeterminerTest, ImportData_Length_FloatVector) {
@@ -1941,8 +1929,8 @@
auto* call = Call(ident, params);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_F(TypeDeterminerTest, ImportData_Length_Error_Integer) {
@@ -1983,8 +1971,8 @@
auto* call = Call(ident, 1.f, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_P(ImportData_TwoParamTest, Vector) {
@@ -1995,9 +1983,9 @@
Call(ident, vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_TwoParamTest, Error_Integer) {
@@ -2076,8 +2064,8 @@
auto* call = Call(ident, 1.f, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_F(TypeDeterminerTest, ImportData_Distance_Vector) {
@@ -2087,8 +2075,8 @@
Call(ident, vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest, ImportData_Distance_Error_Integer) {
@@ -2146,9 +2134,9 @@
Call(ident, vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_F(TypeDeterminerTest, ImportData_Cross_Error_Scalar) {
@@ -2200,8 +2188,8 @@
auto* call = Call(ident, 1.f, 1.f, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_P(ImportData_ThreeParamTest, Vector) {
@@ -2212,9 +2200,9 @@
vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_ThreeParamTest, Error_Integer) {
@@ -2307,8 +2295,8 @@
auto* call = Call(ident, 1.f, 1.f, 1.f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_scalar());
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Float_Vector) {
@@ -2319,9 +2307,9 @@
vec3<f32>(1.0f, 1.0f, 3.0f), vec3<f32>(1.0f, 1.0f, 3.0f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Sint_Scalar) {
@@ -2331,8 +2319,8 @@
auto* call = Call(ident, 1, 1, 1);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::I32>());
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Sint_Vector) {
@@ -2343,9 +2331,9 @@
Call(ident, vec3<i32>(1, 1, 3), vec3<i32>(1, 1, 3), vec3<i32>(1, 1, 3));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_signed_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Uint_Scalar) {
@@ -2355,8 +2343,8 @@
auto* call = Call(ident, 1u, 1u, 1u);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::U32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::U32>());
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Uint_Vector) {
@@ -2367,9 +2355,9 @@
vec3<u32>(1u, 1u, 3u));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_unsigned_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_ThreeParam_FloatOrInt_Test, Error_Bool) {
@@ -2458,8 +2446,8 @@
auto* call = Call(ident, 1);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_integer_scalar());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_integer_scalar());
}
TEST_P(ImportData_Int_SingleParamTest, Vector) {
@@ -2469,9 +2457,9 @@
auto* call = Call(ident, vec3<i32>(1, 1, 3));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_signed_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_Int_SingleParamTest, Error_Float) {
@@ -2521,8 +2509,8 @@
auto* call = Call(ident, 1, 1);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::I32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::I32>());
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Scalar_Unsigned) {
@@ -2532,8 +2520,8 @@
auto* call = Call(ident, 1u, 1u);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::U32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::U32>());
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Scalar_Float) {
@@ -2543,8 +2531,8 @@
auto* call = Call(ident, 1.0f, 1.0f);
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Vector_Signed) {
@@ -2554,9 +2542,9 @@
auto* call = Call(ident, vec3<i32>(1, 1, 3), vec3<i32>(1, 1, 3));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_signed_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_signed_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Vector_Unsigned) {
@@ -2566,9 +2554,9 @@
auto* call = Call(ident, vec3<u32>(1u, 1u, 3u), vec3<u32>(1u, 1u, 3u));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_unsigned_integer_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_unsigned_integer_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Vector_Float) {
@@ -2578,9 +2566,9 @@
auto* call = Call(ident, vec3<f32>(1.f, 1.f, 3.f), vec3<f32>(1.f, 1.f, 3.f));
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->is_float_vector());
- EXPECT_EQ(ident->result_type()->As<type::Vector>()->size(), 3u);
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->is_float_vector());
+ EXPECT_EQ(TypeOf(ident)->As<type::Vector>()->size(), 3u);
}
TEST_P(ImportData_FloatOrInt_TwoParamTest, Error_Bool) {
@@ -2661,8 +2649,8 @@
auto* call = Call(ident, "var");
EXPECT_TRUE(td()->DetermineResultType(call)) << td()->error();
- ASSERT_NE(ident->result_type(), nullptr);
- EXPECT_TRUE(ident->result_type()->Is<type::F32>());
+ ASSERT_NE(TypeOf(ident), nullptr);
+ EXPECT_TRUE(TypeOf(ident)->Is<type::F32>());
}
using ImportData_Matrix_OneParam_Test =
@@ -3132,40 +3120,38 @@
FAIL() << "invalid texture dimensions: " << param.texture_dimension;
case type::TextureDimension::k1d:
case type::TextureDimension::k1dArray:
- EXPECT_EQ(call->result_type()->type_name(), ty.i32()->type_name());
+ EXPECT_EQ(TypeOf(call)->type_name(), ty.i32()->type_name());
break;
case type::TextureDimension::k2d:
case type::TextureDimension::k2dArray:
- EXPECT_EQ(call->result_type()->type_name(),
- ty.vec2<i32>()->type_name());
+ EXPECT_EQ(TypeOf(call)->type_name(), ty.vec2<i32>()->type_name());
break;
case type::TextureDimension::k3d:
case type::TextureDimension::kCube:
case type::TextureDimension::kCubeArray:
- EXPECT_EQ(call->result_type()->type_name(),
- ty.vec3<i32>()->type_name());
+ EXPECT_EQ(TypeOf(call)->type_name(), ty.vec3<i32>()->type_name());
break;
}
} else if (std::string(param.function) == "textureNumLayers") {
- EXPECT_EQ(call->result_type(), ty.i32());
+ EXPECT_EQ(TypeOf(call), ty.i32());
} else if (std::string(param.function) == "textureNumLevels") {
- EXPECT_EQ(call->result_type(), ty.i32());
+ EXPECT_EQ(TypeOf(call), ty.i32());
} else if (std::string(param.function) == "textureNumSamples") {
- EXPECT_EQ(call->result_type(), ty.i32());
+ EXPECT_EQ(TypeOf(call), ty.i32());
} else if (std::string(param.function) == "textureStore") {
- EXPECT_EQ(call->result_type(), ty.void_());
+ EXPECT_EQ(TypeOf(call), ty.void_());
} else {
switch (param.texture_kind) {
case ast::intrinsic::test::TextureKind::kRegular:
case ast::intrinsic::test::TextureKind::kMultisampled:
case ast::intrinsic::test::TextureKind::kStorage: {
auto* datatype = param.resultVectorComponentType(this);
- ASSERT_TRUE(call->result_type()->Is<type::Vector>());
- EXPECT_EQ(call->result_type()->As<type::Vector>()->type(), datatype);
+ ASSERT_TRUE(TypeOf(call)->Is<type::Vector>());
+ EXPECT_EQ(TypeOf(call)->As<type::Vector>()->type(), datatype);
break;
}
case ast::intrinsic::test::TextureKind::kDepth: {
- EXPECT_EQ(call->result_type(), ty.f32());
+ EXPECT_EQ(TypeOf(call), ty.f32());
break;
}
}
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 5227a2f..818ac50 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -30,6 +30,7 @@
#include "src/ast/switch_statement.h"
#include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h"
+#include "src/semantic/expression.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
#include "src/type/i32_type.h"
@@ -236,8 +237,9 @@
type::Type* func_type = current_function_->return_type();
type::Void void_type;
- auto* ret_type =
- ret->has_value() ? ret->value()->result_type()->UnwrapAll() : &void_type;
+ auto* ret_type = ret->has_value()
+ ? program_->Sem().Get(ret->value())->Type()->UnwrapAll()
+ : &void_type;
if (func_type->type_name() != ret_type->type_name()) {
add_error(ret->source(), "v-000y",
@@ -328,7 +330,7 @@
return false;
}
- auto* cond_type = s->condition()->result_type()->UnwrapAll();
+ auto* cond_type = program_->Sem().Get(s->condition())->Type()->UnwrapAll();
if (!cond_type->is_integer_scalar()) {
add_error(s->condition()->source(), "v-0025",
"switch statement selector expression must be of a "
@@ -472,14 +474,14 @@
// Pointers are not storable in WGSL, but the right-hand side must be
// storable. The raw right-hand side might be a pointer value which must be
// loaded (dereferenced) to provide the value to be stored.
- auto* rhs_result_type = rhs->result_type()->UnwrapAll();
+ auto* rhs_result_type = program_->Sem().Get(rhs)->Type()->UnwrapAll();
if (!IsStorable(rhs_result_type)) {
add_error(assign->source(), "v-000x",
"invalid assignment: right-hand-side is not storable: " +
- rhs->result_type()->type_name());
+ program_->Sem().Get(rhs)->Type()->type_name());
return false;
}
- auto* lhs_result_type = lhs->result_type()->UnwrapIfNeeded();
+ auto* lhs_result_type = program_->Sem().Get(lhs)->Type()->UnwrapIfNeeded();
if (auto* lhs_reference_type = As<type::Pointer>(lhs_result_type)) {
auto* lhs_store_type = lhs_reference_type->type()->UnwrapIfNeeded();
if (lhs_store_type != rhs_result_type) {
@@ -497,7 +499,7 @@
add_error(
assign->source(), "v-000x",
"invalid assignment: left-hand-side does not reference storage: " +
- lhs->result_type()->type_name());
+ program_->Sem().Get(lhs)->Type()->type_name());
return false;
}
diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc
index a7bca4c..278040c 100644
--- a/src/validator/validator_test.cc
+++ b/src/validator/validator_test.cc
@@ -130,8 +130,8 @@
Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -153,8 +153,8 @@
Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -178,8 +178,8 @@
RegisterVariable(var_a);
RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -203,8 +203,8 @@
RegisterVariable(var_a);
RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -227,8 +227,8 @@
Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -257,8 +257,8 @@
RegisterVariable(var_a);
RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -286,8 +286,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -313,8 +313,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(block)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -461,8 +461,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -494,8 +494,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -559,8 +559,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
@@ -592,7 +592,6 @@
AST().Functions().Add(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
- EXPECT_TRUE(td()->DetermineFunction(func)) << td()->error();
ValidatorImpl& v = Build();
@@ -622,7 +621,6 @@
AST().Functions().Add(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
- EXPECT_TRUE(td()->DetermineFunction(func)) << td()->error();
ValidatorImpl& v = Build();
@@ -747,8 +745,8 @@
});
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
- ASSERT_NE(lhs->result_type(), nullptr);
- ASSERT_NE(rhs->result_type(), nullptr);
+ ASSERT_NE(TypeOf(lhs), nullptr);
+ ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build();
diff --git a/src/validator/validator_test_helper.h b/src/validator/validator_test_helper.h
index 6e40632..6ee3795 100644
--- a/src/validator/validator_test_helper.h
+++ b/src/validator/validator_test_helper.h
@@ -21,6 +21,7 @@
#include <vector>
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type/void_type.h"
#include "src/type_determiner.h"
#include "src/validator/validator_impl.h"
diff --git a/src/writer/append_vector.cc b/src/writer/append_vector.cc
index 0ac2447..5b62095 100644
--- a/src/writer/append_vector.cc
+++ b/src/writer/append_vector.cc
@@ -18,6 +18,7 @@
#include "src/ast/expression.h"
#include "src/ast/type_constructor_expression.h"
+#include "src/semantic/expression.h"
#include "src/semantic/info.h"
#include "src/type/vector_type.h"
@@ -42,21 +43,18 @@
ast::Expression* scalar) {
uint32_t packed_size;
type::Type* packed_el_ty; // Currently must be f32.
- if (auto* vec = vector->result_type()->As<type::Vector>()) {
+ auto* vector_sem = b->Sem().Get(vector);
+ if (auto* vec = vector_sem->Type()->As<type::Vector>()) {
packed_size = vec->size() + 1;
packed_el_ty = vec->type();
} else {
packed_size = 2;
- packed_el_ty = vector->result_type();
- }
-
- if (!packed_el_ty) {
- return nullptr; // missing type info
+ packed_el_ty = vector_sem->Type();
}
// Cast scalar to the vector element type
auto* scalar_cast = b->Construct(packed_el_ty, scalar);
- scalar_cast->set_result_type(packed_el_ty);
+ b->Sem().Add(scalar_cast, b->create<semantic::Expression>(packed_el_ty));
auto* packed_ty = b->create<type::Vector>(packed_el_ty, packed_size);
@@ -68,14 +66,14 @@
} else {
packed.emplace_back(vector);
}
- if (packed_el_ty != scalar->result_type()) {
+ if (packed_el_ty != b->Sem().Get(scalar)->Type()) {
packed.emplace_back(scalar_cast);
} else {
packed.emplace_back(scalar);
}
auto* constructor = b->Construct(packed_ty, std::move(packed));
- constructor->set_result_type(packed_ty);
+ b->Sem().Add(constructor, b->create<semantic::Expression>(packed_ty));
return constructor;
}
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 5e3caee..7ab03db 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -45,6 +45,7 @@
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
#include "src/type/access_control_type.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
@@ -383,8 +384,8 @@
return true;
}
- auto* lhs_type = expr->lhs()->result_type()->UnwrapAll();
- auto* rhs_type = expr->rhs()->result_type()->UnwrapAll();
+ auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
+ auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
// Multiplying by a matrix requires the use of `mul` in order to get the
// type of multiply we desire.
if (expr->op() == ast::BinaryOp::kMultiply &&
@@ -692,7 +693,7 @@
auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed;
auto* texture = params[pidx.texture];
- auto* texture_type = texture->result_type()->UnwrapAll()->As<type::Texture>();
+ auto* texture_type = TypeOf(texture)->UnwrapAll()->As<type::Texture>();
switch (ident->intrinsic()) {
case ast::Intrinsic::kTextureDimensions:
@@ -887,7 +888,7 @@
auto emit_vector_appended_with_i32_zero = [&](tint::ast::Expression* vector) {
auto* i32 = builder_.create<type::I32>();
auto* zero = builder_.Expr(0);
- zero->set_result_type(i32);
+ builder_.Sem().Add(zero, builder_.create<semantic::Expression>(i32));
auto* packed = AppendVector(&builder_, vector, zero);
return EmitExpression(pre, out, packed);
};
@@ -1857,7 +1858,7 @@
}
first = false;
if (auto* mem = expr->As<ast::MemberAccessorExpression>()) {
- auto* res_type = mem->structure()->result_type()->UnwrapAll();
+ auto* res_type = TypeOf(mem->structure())->UnwrapAll();
if (auto* str = res_type->As<type::Struct>()) {
auto* str_type = str->impl();
auto* str_member = str_type->get_member(mem->member()->symbol());
@@ -1895,7 +1896,7 @@
expr = mem->structure();
} else if (auto* ary = expr->As<ast::ArrayAccessorExpression>()) {
- auto* ary_type = ary->array()->result_type()->UnwrapAll();
+ auto* ary_type = TypeOf(ary->array())->UnwrapAll();
out << "(";
if (auto* arr = ary_type->As<type::Array>()) {
@@ -1942,7 +1943,7 @@
std::ostream& out,
ast::Expression* expr,
ast::Expression* rhs) {
- auto* result_type = expr->result_type()->UnwrapAll();
+ auto* result_type = TypeOf(expr)->UnwrapAll();
bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load";
@@ -2058,7 +2059,7 @@
bool GeneratorImpl::is_storage_buffer_access(
ast::MemberAccessorExpression* expr) {
auto* structure = expr->structure();
- auto* data_type = structure->result_type()->UnwrapAll();
+ auto* data_type = TypeOf(structure)->UnwrapAll();
// TODO(dsinclair): Swizzle
//
// If the data is a multi-element swizzle then we will not load the swizzle
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 4a10975..5c87f36 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -390,6 +390,12 @@
std::string current_ep_var_name(VarType type);
std::string get_buffer_name(ast::Expression* expr);
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ type::Type* TypeOf(ast::Expression* expr) const {
+ return builder_.TypeOf(expr);
+ }
+
std::string error_;
size_t indent_ = 0;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index b0298c9..38a626b 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -50,6 +50,7 @@
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program.h"
+#include "src/semantic/expression.h"
#include "src/type/access_control_type.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
@@ -613,7 +614,7 @@
assert(pidx.texture != kNotUsed);
auto* texture_type =
- params[pidx.texture]->result_type()->UnwrapAll()->As<type::Texture>();
+ TypeOf(params[pidx.texture])->UnwrapAll()->As<type::Texture>();
switch (ident->intrinsic()) {
case ast::Intrinsic::kTextureDimensions: {
@@ -658,7 +659,7 @@
get_dim(dims[0]);
out_ << ")";
} else {
- EmitType(expr->result_type(), "");
+ EmitType(TypeOf(expr), "");
out_ << "(";
for (size_t i = 0; i < dims.size(); i++) {
if (i > 0) {
@@ -764,8 +765,7 @@
}
}
if (pidx.ddx != kNotUsed) {
- auto dim = params[pidx.texture]
- ->result_type()
+ auto dim = TypeOf(params[pidx.texture])
->UnwrapPtrIfNeeded()
->As<type::Texture>()
->dim();
@@ -815,6 +815,7 @@
std::string GeneratorImpl::generate_builtin_name(
ast::IdentifierExpression* ident) {
+ auto* type = TypeOf(ident);
std::string out = "metal::";
switch (ident->intrinsic()) {
case ast::Intrinsic::kAcos:
@@ -852,26 +853,23 @@
out += program_->Symbols().NameFor(ident->symbol());
break;
case ast::Intrinsic::kAbs:
- if (ident->result_type()->Is<type::F32>()) {
+ if (type->Is<type::F32>()) {
out += "fabs";
- } else if (ident->result_type()->Is<type::U32>() ||
- ident->result_type()->Is<type::I32>()) {
+ } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
out += "abs";
}
break;
case ast::Intrinsic::kMax:
- if (ident->result_type()->Is<type::F32>()) {
+ if (type->Is<type::F32>()) {
out += "fmax";
- } else if (ident->result_type()->Is<type::U32>() ||
- ident->result_type()->Is<type::I32>()) {
+ } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
out += "max";
}
break;
case ast::Intrinsic::kMin:
- if (ident->result_type()->Is<type::F32>()) {
+ if (type->Is<type::F32>()) {
out += "fmin";
- } else if (ident->result_type()->Is<type::U32>() ||
- ident->result_type()->Is<type::I32>()) {
+ } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
out += "min";
}
break;
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index d3b56d2..0dee30b 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -280,6 +280,12 @@
std::string current_ep_var_name(VarType type);
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ type::Type* TypeOf(ast::Expression* expr) const {
+ return program_->TypeOf(expr);
+ }
+
Namer namer_;
ScopeStack<ast::Variable*> global_variables_;
Symbol current_ep_sym_;
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 577a842..158c56b 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -59,6 +59,7 @@
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program.h"
+#include "src/semantic/expression.h"
#include "src/type/access_control_type.h"
#include "src/type/alias_type.h"
#include "src/type/array_type.h"
@@ -405,7 +406,8 @@
}
// If the thing we're assigning is a pointer then we must load it first.
- rhs_id = GenerateLoadIfNeeded(assign->rhs()->result_type(), rhs_id);
+ auto* type = TypeOf(assign->rhs());
+ rhs_id = GenerateLoadIfNeeded(type, rhs_id);
return GenerateStore(lhs_id, rhs_id);
}
@@ -639,7 +641,8 @@
if (init_id == 0) {
return false;
}
- init_id = GenerateLoadIfNeeded(var->constructor()->result_type(), init_id);
+ auto* type = TypeOf(var->constructor());
+ init_id = GenerateLoadIfNeeded(type, init_id);
}
if (var->is_const()) {
@@ -843,7 +846,8 @@
if (idx_id == 0) {
return 0;
}
- idx_id = GenerateLoadIfNeeded(expr->idx_expr()->result_type(), idx_id);
+ auto* type = TypeOf(expr->idx_expr());
+ idx_id = GenerateLoadIfNeeded(type, idx_id);
// If the source is a pointer we access chain into it. We also access chain
// into an array of non-scalar types.
@@ -851,11 +855,11 @@
(info->source_type->Is<type::Array>() &&
!info->source_type->As<type::Array>()->type()->is_scalar())) {
info->access_chain_indices.push_back(idx_id);
- info->source_type = expr->result_type();
+ info->source_type = TypeOf(expr);
return true;
}
- auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) {
return false;
}
@@ -872,7 +876,7 @@
}
info->source_id = extract_id;
- info->source_type = expr->result_type();
+ info->source_type = TypeOf(expr);
return true;
}
@@ -880,7 +884,8 @@
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) {
auto* data_type =
- expr->structure()->result_type()->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
+ TypeOf(expr->structure())->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
+ auto* expr_type = TypeOf(expr);
// If the data_type is a structure we're accessing a member, if it's a
// vector we're accessing a swizzle.
@@ -908,7 +913,7 @@
return 0;
}
info->access_chain_indices.push_back(idx_id);
- info->source_type = expr->result_type();
+ info->source_type = expr_type;
return true;
}
@@ -934,7 +939,7 @@
}
info->access_chain_indices.push_back(idx_id);
} else {
- auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(expr_type);
if (result_type_id == 0) {
return 0;
}
@@ -949,7 +954,7 @@
}
info->source_id = extract_id;
- info->source_type = expr->result_type();
+ info->source_type = expr_type;
}
return true;
}
@@ -977,12 +982,12 @@
return false;
}
- info->source_id = GenerateLoadIfNeeded(expr->result_type(), extract_id);
- info->source_type = expr->result_type()->UnwrapPtrIfNeeded();
+ info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
+ info->source_type = expr_type->UnwrapPtrIfNeeded();
info->access_chain_indices.clear();
}
- auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(expr_type);
if (result_type_id == 0) {
return false;
}
@@ -1009,7 +1014,7 @@
return false;
}
info->source_id = result_id;
- info->source_type = expr->result_type();
+ info->source_type = expr_type;
return true;
}
@@ -1040,13 +1045,13 @@
if (info.source_id == 0) {
return 0;
}
- info.source_type = source->result_type();
+ info.source_type = TypeOf(source);
// If our initial access is into an array of non-scalar types, and that array
// is not a pointer, then we need to load that array into a variable in order
// to access chain into the array.
if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) {
- auto* ary_res_type = array->array()->result_type();
+ auto* ary_res_type = TypeOf(array->array());
if (!ary_res_type->Is<type::Pointer>() &&
(ary_res_type->Is<type::Array>() &&
@@ -1095,7 +1100,7 @@
}
if (!info.access_chain_indices.empty()) {
- auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) {
return 0;
}
@@ -1153,16 +1158,16 @@
if (val_id == 0) {
return 0;
}
- val_id = GenerateLoadIfNeeded(expr->expr()->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id);
- auto type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) {
return 0;
}
spv::Op op = spv::Op::OpNop;
if (expr->op() == ast::UnaryOp::kNegation) {
- if (expr->result_type()->is_float_scalar_or_vector()) {
+ if (TypeOf(expr)->is_float_scalar_or_vector()) {
op = spv::Op::OpFNegate;
} else {
op = spv::Op::OpSNegate;
@@ -1260,7 +1265,7 @@
} else if (auto* str = subtype->As<type::Struct>()) {
subtype = str->impl()->members()[i]->type()->UnwrapAll();
}
- if (subtype != sc->result_type()->UnwrapAll()) {
+ if (subtype != TypeOf(sc)->UnwrapAll()) {
return false;
}
}
@@ -1291,7 +1296,7 @@
if (auto* res_vec = result_type->As<type::Vector>()) {
if (res_vec->type()->is_scalar()) {
- auto* value_type = values[0]->result_type()->UnwrapAll();
+ auto* value_type = TypeOf(values[0])->UnwrapAll();
if (auto* val_vec = value_type->As<type::Vector>()) {
if (val_vec->type()->is_scalar()) {
can_cast_or_copy = res_vec->size() == val_vec->size();
@@ -1324,13 +1329,13 @@
nullptr, e->As<ast::ConstructorExpression>(), is_global_init);
} else {
id = GenerateExpression(e);
- id = GenerateLoadIfNeeded(e->result_type(), id);
+ id = GenerateLoadIfNeeded(TypeOf(e), id);
}
if (id == 0) {
return 0;
}
- auto* value_type = e->result_type()->UnwrapPtrIfNeeded();
+ auto* value_type = TypeOf(e)->UnwrapPtrIfNeeded();
// If the result and value types are the same we can just use the object.
// If the result is not a vector then we should have validated that the
// value type is a correctly sized vector so we can just use it directly.
@@ -1445,9 +1450,9 @@
if (val_id == 0) {
return 0;
}
- val_id = GenerateLoadIfNeeded(from_expr->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
- auto* from_type = from_expr->result_type()->UnwrapPtrIfNeeded();
+ auto* from_type = TypeOf(from_expr)->UnwrapPtrIfNeeded();
spv::Op op = spv::Op::OpNop;
if ((from_type->Is<type::I32>() && to_type->Is<type::F32>()) ||
@@ -1557,13 +1562,13 @@
if (lhs_id == 0) {
return false;
}
- lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id);
+ lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs()), lhs_id);
// Get the ID of the basic block where control flow will diverge. It's the
// last basic block generated for the left-hand-side of the operator.
auto original_label_id = current_label_id_;
- auto type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) {
return 0;
}
@@ -1601,7 +1606,7 @@
if (rhs_id == 0) {
return 0;
}
- rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id);
+ rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs()), rhs_id);
// Get the block ID of the last basic block generated for the right-hand-side
// expression. That block will be an immediate predecessor to the merge block.
@@ -1638,26 +1643,26 @@
if (lhs_id == 0) {
return 0;
}
- lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id);
+ lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs()), lhs_id);
auto rhs_id = GenerateExpression(expr->rhs());
if (rhs_id == 0) {
return 0;
}
- rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id);
+ rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs()), rhs_id);
auto result = result_op();
auto result_id = result.to_i();
- auto type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) {
return 0;
}
// Handle int and float and the vectors of those types. Other types
// should have been rejected by validation.
- auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
- auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded();
+ auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
+ auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
@@ -1806,7 +1811,7 @@
return GenerateIntrinsic(ident, expr);
}
- auto type_id = GenerateTypeIfNeeded(expr->func()->result_type());
+ auto type_id = GenerateTypeIfNeeded(TypeOf(expr->func()));
if (type_id == 0) {
return 0;
}
@@ -1829,7 +1834,7 @@
if (id == 0) {
return 0;
}
- id = GenerateLoadIfNeeded(param->result_type(), id);
+ id = GenerateLoadIfNeeded(TypeOf(param), id);
ops.push_back(Operand::Int(id));
}
@@ -1845,7 +1850,7 @@
auto result = result_op();
auto result_id = result.to_i();
- auto result_type_id = GenerateTypeIfNeeded(call->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(TypeOf(call));
if (result_type_id == 0) {
return 0;
}
@@ -1895,7 +1900,7 @@
}
params.push_back(Operand::Int(struct_id));
- auto* type = accessor->structure()->result_type()->UnwrapAll();
+ auto* type = TypeOf(accessor->structure())->UnwrapAll();
if (!type->Is<type::Struct>()) {
error_ =
"invalid type (" + type->type_name() + ") for runtime array length";
@@ -1948,8 +1953,7 @@
return 0;
}
auto set_id = set_iter->second;
- auto inst_id =
- intrinsic_to_glsl_method(ident->result_type(), ident->intrinsic());
+ auto inst_id = intrinsic_to_glsl_method(TypeOf(ident), ident->intrinsic());
if (inst_id == 0) {
error_ = "unknown method " + builder_.Symbols().NameFor(ident->symbol());
return 0;
@@ -1972,7 +1976,7 @@
if (val_id == 0) {
return false;
}
- val_id = GenerateLoadIfNeeded(p->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(p), val_id);
params.emplace_back(Operand::Int(val_id));
}
@@ -1995,10 +1999,8 @@
auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed;
assert(pidx.texture != kNotUsed);
- auto* texture_type = call->params()[pidx.texture]
- ->result_type()
- ->UnwrapAll()
- ->As<type::Texture>();
+ auto* texture_type =
+ TypeOf(call->params()[pidx.texture])->UnwrapAll()->As<type::Texture>();
auto op = spv::Op::OpNop;
@@ -2008,7 +2010,7 @@
if (val_id == 0) {
return Operand::Int(0);
}
- val_id = GenerateLoadIfNeeded(p->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(p), val_id);
return Operand::Int(val_id);
};
@@ -2076,7 +2078,7 @@
} else {
// Assign post_emission to swizzle the result of the call to
// OpImageQuerySize[Lod].
- auto* element_type = ElementTypeOf(call->result_type());
+ auto* element_type = ElementTypeOf(TypeOf(call));
auto spirv_result = result_op();
auto* spirv_result_type =
builder_.create<type::Vector>(element_type, spirv_result_width);
@@ -2302,7 +2304,7 @@
}
assert(pidx.level != kNotUsed);
auto level = Operand::Int(0);
- if (call->params()[pidx.level]->result_type()->Is<type::I32>()) {
+ if (TypeOf(call->params()[pidx.level])->Is<type::I32>()) {
// Depth textures have i32 parameters for the level, but SPIR-V expects
// F32. Cast.
auto* f32 = builder_.create<type::F32>();
@@ -2417,7 +2419,7 @@
auto result = result_op();
auto result_id = result.to_i();
- auto result_type_id = GenerateTypeIfNeeded(expr->result_type());
+ auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) {
return 0;
}
@@ -2426,11 +2428,11 @@
if (val_id == 0) {
return 0;
}
- val_id = GenerateLoadIfNeeded(expr->expr()->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id);
// Bitcast does not allow same types, just emit a CopyObject
- auto* to_type = expr->result_type()->UnwrapPtrIfNeeded();
- auto* from_type = expr->expr()->result_type()->UnwrapPtrIfNeeded();
+ auto* to_type = TypeOf(expr)->UnwrapPtrIfNeeded();
+ auto* from_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded();
if (to_type->type_name() == from_type->type_name()) {
if (!push_function_inst(
spv::Op::OpCopyObject,
@@ -2457,7 +2459,7 @@
if (cond_id == 0) {
return false;
}
- cond_id = GenerateLoadIfNeeded(cond->result_type(), cond_id);
+ cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
auto merge_block = result_op();
auto merge_block_id = merge_block.to_i();
@@ -2545,7 +2547,7 @@
if (cond_id == 0) {
return false;
}
- cond_id = GenerateLoadIfNeeded(stmt->condition()->result_type(), cond_id);
+ cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition()), cond_id);
auto default_block = result_op();
auto default_block_id = default_block.to_i();
@@ -2641,7 +2643,7 @@
if (val_id == 0) {
return false;
}
- val_id = GenerateLoadIfNeeded(stmt->value()->result_type(), val_id);
+ val_id = GenerateLoadIfNeeded(TypeOf(stmt->value()), val_id);
if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
return false;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 2c73098..ee5f37f 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -488,6 +488,12 @@
/// automatically.
Operand result_op();
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ type::Type* TypeOf(ast::Expression* expr) const {
+ return builder_.TypeOf(expr);
+ }
+
ProgramBuilder builder_;
std::string error_;
uint32_t next_id_ = 1;