[validation] Validates if return statement type matches function return type
This CL checks if the return statement type matches the function return type
Bug: tint 6
Change-Id: I621d67086291c392b68261673a25c0e6caca71ae
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26860
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/validator.cc b/src/validator.cc
index 5af8414..3186d9e 100644
--- a/src/validator.cc
+++ b/src/validator.cc
@@ -13,7 +13,6 @@
// limitations under the License.
#include "src/validator.h"
-
#include "src/validator_impl.h"
namespace tint {
diff --git a/src/validator_function_test.cc b/src/validator_function_test.cc
index b220e45..e23216a 100644
--- a/src/validator_function_test.cc
+++ b/src/validator_function_test.cc
@@ -88,10 +88,7 @@
"12:34: v-0002: function must end with a return statement");
}
-TEST_F(ValidateFunctionTest,
- DISABLED_FunctionTypeMustMatchReturnStatementType_pass) {
- // TODO(sarahM0): remove DISABLED after implementing function type must match
- // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) {
// fn func -> void { return; }
ast::type::VoidType void_type;
ast::VariableList params;
@@ -107,10 +104,7 @@
EXPECT_TRUE(v.Validate(mod())) << v.error();
}
-TEST_F(ValidateFunctionTest,
- DISABLED_FunctionTypeMustMatchReturnStatementType_fail) {
- // TODO(sarahM0): remove DISABLED after implementing function type must match
- // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) {
// fn func -> void { return 2; }
ast::type::VoidType void_type;
ast::type::I32Type i32;
@@ -130,17 +124,13 @@
tint::ValidatorImpl v;
EXPECT_FALSE(v.Validate(mod()));
// TODO(sarahM0): replace 000y with a rule number
- EXPECT_EQ(
- v.error(),
- "12:34: v-000y: function type must match its return statement type");
+ EXPECT_EQ(v.error(),
+ "12:34: v-000y: return statement type must match its function "
+ "return type, returned '__i32', expected '__void'");
}
-TEST_F(ValidateFunctionTest,
- DISABLED_FunctionTypeMustMatchReturnStatementTypeF32_fail) {
- // TODO(sarahM0): remove DISABLED after implementing function type must match
- // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) {
// fn func -> f32 { return 2; }
- ast::type::VoidType void_type;
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::VariableList params;
@@ -158,9 +148,9 @@
tint::ValidatorImpl v;
EXPECT_FALSE(v.Validate(mod()));
// TODO(sarahM0): replace 000y with a rule number
- EXPECT_EQ(
- v.error(),
- "12:34: v-000y: function type must match its return statement type");
+ EXPECT_EQ(v.error(),
+ "12:34: v-000y: return statement type must match its function "
+ "return type, returned '__i32', expected '__f32'");
}
TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
diff --git a/src/validator_impl.cc b/src/validator_impl.cc
index 0bbda6c..672948d 100644
--- a/src/validator_impl.cc
+++ b/src/validator_impl.cc
@@ -14,6 +14,7 @@
#include "src/validator_impl.h"
#include "src/ast/function.h"
+#include "src/ast/type/void_type.h"
#include "src/ast/variable_decl_statement.h"
namespace tint {
@@ -60,10 +61,11 @@
}
function_stack_.set(func_ptr->name(), func_ptr);
-
+ current_function_ = func_ptr;
if (!ValidateFunction(func_ptr)) {
return false;
}
+ current_function_ = nullptr;
}
return true;
}
@@ -87,6 +89,28 @@
return true;
}
+bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
+ // TODO(sarahM0): update this when this issue resolves:
+ // https://github.com/gpuweb/gpuweb/issues/996
+ ast::type::Type* func_type = current_function_->return_type();
+
+ ast::type::VoidType void_type;
+ auto* ret_type = ret->has_value()
+ ? ret->value()->result_type()->UnwrapAliasPtrAlias()
+ : &void_type;
+
+ if (func_type->type_name() != ret_type->type_name()) {
+ set_error(ret->source(),
+ "v-000y: return statement type must match its function return "
+ "type, returned '" +
+ ret_type->type_name() + "', expected '" +
+ func_type->type_name() + "'");
+ return false;
+ }
+
+ return true;
+}
+
bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) {
if (!block) {
return false;
@@ -126,6 +150,9 @@
if (stmt->IsAssign()) {
return ValidateAssign(stmt->AsAssign());
}
+ if (stmt->IsReturn()) {
+ return ValidateReturnStatement(stmt->AsReturn());
+ }
return true;
}
diff --git a/src/validator_impl.h b/src/validator_impl.h
index 8cf7b96..6f0488e 100644
--- a/src/validator_impl.h
+++ b/src/validator_impl.h
@@ -22,8 +22,8 @@
#include "src/ast/expression.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/module.h"
+#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
-#include "src/ast/type/type.h"
#include "src/ast/variable.h"
#include "src/scope_stack.h"
@@ -96,11 +96,16 @@
/// @returns true if no previous decleration with the |decl|'s name
/// exist in the variable stack
bool ValidateDeclStatement(const ast::VariableDeclStatement* decl);
+ /// Validates return statement
+ /// @param ret the return statement to check
+ /// @returns true if function return type matches the return statement type
+ bool ValidateReturnStatement(const ast::ReturnStatement* ret);
private:
std::string error_;
ScopeStack<ast::Variable*> variable_stack_;
ScopeStack<ast::Function*> function_stack_;
+ ast::Function* current_function_ = nullptr;
};
} // namespace tint
diff --git a/src/validator_test.cc b/src/validator_test.cc
index 9f4ada0..aaed618 100644
--- a/src/validator_test.cc
+++ b/src/validator_test.cc
@@ -674,7 +674,7 @@
ast::VariableList params1;
auto func1 =
- std::make_unique<ast::Function>("func1", std::move(params1), &f32);
+ std::make_unique<ast::Function>("func1", std::move(params1), &void_type);
auto body1 = std::make_unique<ast::BlockStatement>();
body1->append(std::make_unique<ast::VariableDeclStatement>(Source{13, 34},
std::move(var1)));