[validation] Checks if recursions exist This CL validates the following rule. ie. As functions must be defined before use (v-0005), self-recursion is only case that has to be invalidated. v-0004: Recursions are not allowed. Bug: tint: 6 Change-Id: Icfb040907c5ea0abb6359dade74dcfc30a0db7d9 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26980 Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/validator_function_test.cc b/src/validator_function_test.cc index e23216a..fb484c7 100644 --- a/src/validator_function_test.cc +++ b/src/validator_function_test.cc
@@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "spirv/unified1/GLSL.std.450.h" +#include "src/ast/call_statement.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" @@ -188,5 +189,56 @@ EXPECT_EQ(v.error(), "12:34: v-0016: function names must be unique 'func'"); } +TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) { + // fn func() -> void {func(); return; } + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::ExpressionList call_params; + auto call_expr = std::make_unique<ast::CallExpression>( + Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"), + std::move(call_params)); + ast::VariableList params0; + auto func0 = + std::make_unique<ast::Function>("func", std::move(params0), &f32); + auto body0 = std::make_unique<ast::BlockStatement>(); + body0->append(std::make_unique<ast::CallStatement>(std::move(call_expr))); + body0->append(std::make_unique<ast::ReturnStatement>()); + func0->set_body(std::move(body0)); + mod()->AddFunction(std::move(func0)); + + EXPECT_TRUE(td()->Determine()) << td()->error(); + tint::ValidatorImpl v; + EXPECT_FALSE(v.Validate(mod())) << v.error(); + EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'"); +} + +TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) { + // fn func() -> i32 {var a: i32 = func(); return 2; } + ast::type::I32Type i32; + auto var = + std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &i32); + ast::ExpressionList call_params; + auto call_expr = std::make_unique<ast::CallExpression>( + Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"), + std::move(call_params)); + var->set_constructor(std::move(call_expr)); + ast::VariableList params0; + auto func0 = + std::make_unique<ast::Function>("func", std::move(params0), &i32); + auto body0 = std::make_unique<ast::BlockStatement>(); + body0->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + auto return_expr = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::SintLiteral>(&i32, 2)); + + body0->append(std::make_unique<ast::ReturnStatement>(std::move(return_expr))); + func0->set_body(std::move(body0)); + mod()->AddFunction(std::move(func0)); + + EXPECT_TRUE(td()->Determine()) << td()->error(); + tint::ValidatorImpl v; + EXPECT_FALSE(v.Validate(mod())) << v.error(); + EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'"); +} + } // namespace } // namespace tint
diff --git a/src/validator_impl.cc b/src/validator_impl.cc index 672948d..d978cc7 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc
@@ -13,7 +13,9 @@ // limitations under the License. #include "src/validator_impl.h" +#include "src/ast/call_statement.h" #include "src/ast/function.h" +#include "src/ast/intrinsic.h" #include "src/ast/type/void_type.h" #include "src/ast/variable_decl_statement.h" @@ -145,7 +147,13 @@ return false; } if (stmt->IsVariableDecl()) { - return ValidateDeclStatement(stmt->AsVariableDecl()); + auto* v = stmt->AsVariableDecl(); + bool constructor_valid = + v->variable()->has_constructor() + ? ValidateExpression(v->variable()->constructor()) + : true; + + return constructor_valid && ValidateDeclStatement(stmt->AsVariableDecl()); } if (stmt->IsAssign()) { return ValidateAssign(stmt->AsAssign()); @@ -153,6 +161,44 @@ if (stmt->IsReturn()) { return ValidateReturnStatement(stmt->AsReturn()); } + if (stmt->IsCall()) { + return ValidateCallExpr(stmt->AsCall()->expr()); + } + return true; +} + +bool ValidatorImpl::ValidateCallExpr(const ast::CallExpression* expr) { + if (!expr) { + // TODO(sarahM0): Here and other Validate.*: figure out whether return false + // or true + return false; + } + + if (expr->func()->IsIdentifier()) { + auto* ident = expr->func()->AsIdentifier(); + auto func_name = ident->name(); + if (ident->has_path()) { + // TODO(sarahM0): validate import statements + } else if (ast::intrinsic::IsIntrinsic(ident->name())) { + // TODO(sarahM0): validate intrinsics - tied with type-determiner + } else { + if (!function_stack_.has(func_name)) { + set_error(expr->source(), + "v-0005: function must be declared before use: '" + + func_name + "'"); + return false; + } + if (func_name == current_function_->name()) { + set_error(expr->source(), + "v-0004: recursion is not allowed: '" + func_name + "'"); + return false; + } + } + } else { + set_error(expr->source(), "Invalid function call expression"); + return false; + } + return true; } @@ -218,6 +264,9 @@ return ValidateIdentifier(expr->AsIdentifier()); } + if (expr->IsCall()) { + return ValidateCallExpr(expr->AsCall()); + } return true; }
diff --git a/src/validator_impl.h b/src/validator_impl.h index 6f0488e..5ed21cf 100644 --- a/src/validator_impl.h +++ b/src/validator_impl.h
@@ -19,6 +19,7 @@ #include <unordered_map> #include "src/ast/assignment_statement.h" +#include "src/ast/call_expression.h" #include "src/ast/expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/module.h" @@ -77,7 +78,7 @@ bool CheckImports(const ast::Module* module); /// Validates an expression /// @param expr the expression to check - /// @return true if the expresssion is valid + /// @return true if the expression is valid bool ValidateExpression(const ast::Expression* expr); /// Validates v-0006:Variables must be defined before use /// @param ident the identifer to check if its in the scope @@ -100,6 +101,10 @@ /// @param ret the return statement to check /// @returns true if function return type matches the return statement type bool ValidateReturnStatement(const ast::ReturnStatement* ret); + /// Validates function calls + /// @param expr the call to validate + /// @returns true if successful + bool ValidateCallExpr(const ast::CallExpression* expr); private: std::string error_;
diff --git a/src/validator_test.cc b/src/validator_test.cc index 853fe6e..1df7170 100644 --- a/src/validator_test.cc +++ b/src/validator_test.cc
@@ -690,27 +690,32 @@ EXPECT_TRUE(v.Validate(mod())) << v.error(); } -TEST_F(ValidatorTest, DISABLED_RecursionIsNotAllowed_Fail) { - // fn func() -> void {func(); return; } - ast::type::F32Type f32; - ast::type::VoidType void_type; - ast::ExpressionList call_params; - auto call_expr = std::make_unique<ast::CallExpression>( - Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"), - std::move(call_params)); - ast::VariableList params0; - auto func0 = - std::make_unique<ast::Function>("func", std::move(params0), &f32); - auto body0 = std::make_unique<ast::BlockStatement>(); - body0->append(std::make_unique<ast::CallStatement>(std::move(call_expr))); - body0->append(std::make_unique<ast::ReturnStatement>()); - func0->set_body(std::move(body0)); - mod()->AddFunction(std::move(func0)); +TEST_F(ValidatorTest, VariableDeclNoConstructor_Pass) { + // { + // var a :i32; + // a = 2; + // } + ast::type::I32Type i32; + auto var = + std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &i32); - EXPECT_TRUE(td()->Determine()) << td()->error(); + td()->RegisterVariableForTesting(var.get()); + auto lhs = std::make_unique<ast::IdentifierExpression>("a"); + auto* lhs_ptr = lhs.get(); + auto rhs = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::SintLiteral>(&i32, 2)); + auto* rhs_ptr = rhs.get(); + + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::AssignmentStatement>( + Source{12, 34}, std::move(lhs), std::move(rhs))); + + EXPECT_TRUE(td()->DetermineStatements(body.get())) << td()->error(); + ASSERT_NE(lhs_ptr->result_type(), nullptr); + ASSERT_NE(rhs_ptr->result_type(), nullptr); tint::ValidatorImpl v; - EXPECT_FALSE(v.Validate(mod())) << v.error(); - EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'"); + EXPECT_TRUE(v.ValidateStatements(body.get())) << v.error(); } } // namespace