Add type determination for IdentifierExpression.
This CL adds the type determination for identifier expressions.
Namespaced identifiers are not determined yet.
Bug: tint:5
Change-Id: Id8f39ad122cef0349393de4d429a6d971b2a7ce8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18841
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index b5b9ff1..bb25fbf 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -19,6 +19,7 @@
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/regardless_statement.h"
@@ -176,6 +177,9 @@
if (expr->IsConstructor()) {
return DetermineConstructor(expr->AsConstructor());
}
+ if (expr->IsIdentifier()) {
+ return DetermineIdentifier(expr->AsIdentifier());
+ }
error_ = "unknown expression for type determination";
return false;
@@ -190,4 +194,28 @@
return true;
}
+bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
+ if (expr->name().size() > 1) {
+ // TODO(dsinclair): Handle imports
+ error_ = "imports not handled in type determination";
+ return false;
+ }
+
+ auto name = expr->name()[0];
+ ast::Variable* var;
+ if (variable_stack_.get(name, &var)) {
+ expr->set_result_type(var->type());
+ return true;
+ }
+
+ auto iter = name_to_function_.find(name);
+ if (iter != name_to_function_.end()) {
+ expr->set_result_type(iter->second->return_type());
+ return true;
+ }
+
+ error_ = "unknown identifier for type determination";
+ return false;
+}
+
} // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index f8d8388..03ad01d 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -26,6 +26,7 @@
namespace ast {
class ConstructorExpression;
+class IdentifierExpression;
class Function;
class Variable;
@@ -69,7 +70,7 @@
private:
bool DetermineConstructor(ast::ConstructorExpression* expr);
-
+ bool DetermineIdentifier(ast::IdentifierExpression* expr);
Context& ctx_;
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index da8c7df..c4fec3a 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -24,6 +24,7 @@
#include "src/ast/continue_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/float_literal.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/int_literal.h"
#include "src/ast/loop_statement.h"
@@ -406,5 +407,64 @@
EXPECT_EQ(tc.result_type()->AsVector()->size(), 3);
}
+TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) {
+ ast::type::F32Type f32;
+
+ ast::Module m;
+ auto var =
+ std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ EXPECT_TRUE(td()->Determine(&m));
+
+ ast::IdentifierExpression ident("my_var");
+ EXPECT_TRUE(td()->DetermineResultType(&ident));
+ ASSERT_NE(ident.result_type(), nullptr);
+ EXPECT_TRUE(ident.result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
+ ast::type::F32Type f32;
+
+ auto my_var = std::make_unique<ast::IdentifierExpression>("my_var");
+ auto my_var_ptr = my_var.get();
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::VariableDeclStatement>(
+ std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+ &f32)));
+
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::move(my_var),
+ std::make_unique<ast::IdentifierExpression>("my_var")));
+
+ ast::Function f("my_func", {}, &f32);
+ f.set_body(std::move(body));
+
+ EXPECT_TRUE(td()->DetermineFunction(&f));
+
+ ASSERT_NE(my_var_ptr->result_type(), nullptr);
+ EXPECT_TRUE(my_var_ptr->result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
+ ast::type::F32Type f32;
+
+ ast::VariableList params;
+ auto func =
+ std::make_unique<ast::Function>("my_func", std::move(params), &f32);
+ ast::Module m;
+ m.AddFunction(std::move(func));
+
+ // Register the function
+ EXPECT_TRUE(td()->Determine(&m));
+
+ ast::IdentifierExpression ident("my_func");
+ EXPECT_TRUE(td()->DetermineResultType(&ident));
+ ASSERT_NE(ident.result_type(), nullptr);
+ EXPECT_TRUE(ident.result_type()->IsF32());
+}
+
} // namespace
} // namespace tint