blob: f83e639a738bf2c53e6ce3719f6877ffe32ea8f9 [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/fuzzers/tint_ast_fuzzer/jump_tracker.h"
#include <string>
#include "gtest/gtest.h"
#include "src/tint/ast/block_statement.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/discard_statement.h"
#include "src/tint/ast/for_loop_statement.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/module.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/program.h"
#include "src/tint/reader/wgsl/parser.h"
namespace tint::fuzzers::ast_fuzzer {
namespace {
// Verifies that JumpTracker::ContainsBreakForInnermostLoop() returns true for
// exactly the statements collected in `containing_loop_break`, and false for
// every other statement of the parsed program.
TEST(JumpTrackerTest, Breaks) {
    // The shader nests a `loop` inside a `for` inside another `for`, and also
    // places `break` statements inside a `switch`. Only the statements listed in
    // `containing_loop_break` below are expected to be reported by the tracker;
    // the statements inside the `switch` are absent from that set.
    std::string content = R"(
fn main() {
var x : u32;
for (var i : i32 = 0; i < 100; i++) {
if (i == 40) {
{
break;
}
}
for (var j : i32 = 0; j < 10; j++) {
loop {
if (i > j) {
break;
}
continuing {
i++;
j-=2;
}
}
switch (j) {
case 0: {
if (i == j) {
break;
}
i = i + 1;
continue;
}
default: {
break;
}
}
}
}
}
)";
    Source::File file("test.wgsl", content);
    auto program = reader::wgsl::Parse(&file);
    ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();

    JumpTracker jump_tracker(program);

    // Walk the AST down to the nodes of interest. Every As<>() result is
    // asserted to be non-null before it is dereferenced, so that an unexpected
    // AST shape is reported as a test failure instead of crashing the binary.
    const auto* outer_loop =
        program.AST().Functions()[0]->body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(outer_loop != nullptr);
    const auto* outer_loop_body = outer_loop->body;

    // First statement of the outer loop: `if (i == 40) { { break; } }`.
    const auto* first_if = outer_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(first_if != nullptr);
    const auto* first_if_body = first_if->body;
    const auto* block_in_first_if = first_if_body->statements[0]->As<ast::BlockStatement>();
    ASSERT_TRUE(block_in_first_if != nullptr);
    const auto* break_in_first_if = block_in_first_if->statements[0]->As<ast::BreakStatement>();
    ASSERT_TRUE(break_in_first_if != nullptr);

    // Second statement of the outer loop is the inner `for`, whose first
    // statement is the `loop { ... }`.
    const auto* inner_for_loop = outer_loop_body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(inner_for_loop != nullptr);
    const auto* innermost_loop = inner_for_loop->body->statements[0]->As<ast::LoopStatement>();
    ASSERT_TRUE(innermost_loop != nullptr);
    const auto* innermost_loop_body = innermost_loop->body;

    // `if (i > j) { break; }` inside the `loop`.
    const auto* innermost_loop_if = innermost_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(innermost_loop_if != nullptr);
    const auto* innermost_loop_if_body = innermost_loop_if->body;
    const auto* break_in_innermost_loop =
        innermost_loop_if_body->statements[0]->As<ast::BreakStatement>();
    ASSERT_TRUE(break_in_innermost_loop != nullptr);

    // The complete set of statements for which the tracker must answer true.
    std::unordered_set<const ast::Statement*> containing_loop_break = {
        outer_loop_body,   first_if,
        first_if_body,     block_in_first_if,
        break_in_first_if, innermost_loop_body,
        innermost_loop_if, innermost_loop_if_body,
        break_in_innermost_loop};

    // Check every statement node in the program against the expected set.
    for (auto* node : program.ASTNodes().Objects()) {
        auto* stmt = node->As<ast::Statement>();
        if (stmt == nullptr) {
            // Not a statement; the query does not apply.
            continue;
        }
        if (containing_loop_break.count(stmt) > 0) {
            ASSERT_TRUE(jump_tracker.ContainsBreakForInnermostLoop(*stmt));
        } else {
            ASSERT_FALSE(jump_tracker.ContainsBreakForInnermostLoop(*stmt));
        }
    }
}
// Verifies that JumpTracker::ContainsReturn() returns true for exactly the
// statements collected in `containing_return`, and false for every other
// statement of the parsed program.
TEST(JumpTrackerTest, Returns) {
    // The shader places a `return` in three spots: inside a nested block in the
    // outer `for`, inside the `loop`, and inside the `default` case of the
    // `switch`. The `case 0` clause contains no `return`.
    std::string content = R"(
fn main() {
var x : u32;
for (var i : i32 = 0; i < 100; i++) {
if (i == 40) {
{
return;
}
}
for (var j : i32 = 0; j < 10; j++) {
loop {
if (i > j) {
return;
}
continuing {
i++;
j-=2;
}
}
switch (j) {
case 0: {
if (i == j) {
break;
}
i = i + 1;
continue;
}
default: {
return;
}
}
}
}
}
)";
    Source::File file("test.wgsl", content);
    auto program = reader::wgsl::Parse(&file);
    ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();

    JumpTracker jump_tracker(program);

    // Walk the AST down to the nodes of interest. Every As<>() result is
    // asserted to be non-null before it is dereferenced, so that an unexpected
    // AST shape is reported as a test failure instead of crashing the binary.
    const auto* function_body = program.AST().Functions()[0]->body;
    const auto* outer_loop = function_body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(outer_loop != nullptr);
    const auto* outer_loop_body = outer_loop->body;

    // First statement of the outer loop: `if (i == 40) { { return; } }`.
    const auto* first_if = outer_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(first_if != nullptr);
    const auto* first_if_body = first_if->body;
    const auto* block_in_first_if = first_if_body->statements[0]->As<ast::BlockStatement>();
    ASSERT_TRUE(block_in_first_if != nullptr);
    const auto* return_in_first_if = block_in_first_if->statements[0]->As<ast::ReturnStatement>();
    ASSERT_TRUE(return_in_first_if != nullptr);

    // Second statement of the outer loop is the inner `for`.
    const auto* inner_for_loop = outer_loop_body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(inner_for_loop != nullptr);
    const auto* inner_for_loop_body = inner_for_loop->body;

    // First statement of the inner `for`: the `loop` with `if (i > j) { return; }`.
    const auto* innermost_loop = inner_for_loop_body->statements[0]->As<ast::LoopStatement>();
    ASSERT_TRUE(innermost_loop != nullptr);
    const auto* innermost_loop_body = innermost_loop->body;
    const auto* innermost_loop_if = innermost_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(innermost_loop_if != nullptr);
    const auto* innermost_loop_if_body = innermost_loop_if->body;
    const auto* return_in_innermost_loop =
        innermost_loop_if_body->statements[0]->As<ast::ReturnStatement>();
    ASSERT_TRUE(return_in_innermost_loop != nullptr);

    // Second statement of the inner `for`: the `switch`; its second clause is
    // the `default` case holding the third `return`.
    const auto* switch_statement = inner_for_loop_body->statements[1]->As<ast::SwitchStatement>();
    ASSERT_TRUE(switch_statement != nullptr);
    const auto* default_statement = switch_statement->body[1];
    const auto* default_statement_body = default_statement->body;
    const auto* return_in_default_statement =
        default_statement_body->statements[0]->As<ast::ReturnStatement>();
    ASSERT_TRUE(return_in_default_statement != nullptr);

    // The complete set of statements for which the tracker must answer true.
    std::unordered_set<const ast::Statement*> containing_return = {
        function_body,          outer_loop,
        outer_loop_body,        first_if,
        first_if_body,          block_in_first_if,
        return_in_first_if,     inner_for_loop,
        inner_for_loop_body,    innermost_loop,
        innermost_loop_body,    innermost_loop_if,
        innermost_loop_if_body, return_in_innermost_loop,
        switch_statement,       default_statement,
        default_statement_body, return_in_default_statement};

    // Check every statement node in the program against the expected set.
    for (auto* node : program.ASTNodes().Objects()) {
        auto* stmt = node->As<ast::Statement>();
        if (stmt == nullptr) {
            // Not a statement; the query does not apply.
            continue;
        }
        if (containing_return.count(stmt) > 0) {
            ASSERT_TRUE(jump_tracker.ContainsReturn(*stmt));
        } else {
            ASSERT_FALSE(jump_tracker.ContainsReturn(*stmt));
        }
    }
}
// Verifies that JumpTracker::ContainsIntraproceduralDiscard() returns true for
// exactly the statements collected in `containing_discard`, and false for
// every other statement of the parsed program.
TEST(JumpTrackerTest, Discards) {
    // The shader places a `discard` in three spots: inside a nested block in the
    // outer `for`, inside the `loop`, and inside the `default` case of the
    // `switch`. The `case 0` clause contains no `discard`.
    std::string content = R"(
fn main() {
var x : u32;
for (var i : i32 = 0; i < 100; i++) {
if (i == 40) {
{
discard;
}
}
for (var j : i32 = 0; j < 10; j++) {
loop {
if (i > j) {
discard;
}
continuing {
i++;
j-=2;
}
}
switch (j) {
case 0: {
if (i == j) {
break;
}
i = i + 1;
continue;
}
default: {
discard;
}
}
}
}
}
)";
    Source::File file("test.wgsl", content);
    auto program = reader::wgsl::Parse(&file);
    ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();

    JumpTracker jump_tracker(program);

    // Walk the AST down to the nodes of interest. Every As<>() result is
    // asserted to be non-null before it is dereferenced, so that an unexpected
    // AST shape is reported as a test failure instead of crashing the binary.
    const auto* function_body = program.AST().Functions()[0]->body;
    const auto* outer_loop = function_body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(outer_loop != nullptr);
    const auto* outer_loop_body = outer_loop->body;

    // First statement of the outer loop: `if (i == 40) { { discard; } }`.
    const auto* first_if = outer_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(first_if != nullptr);
    const auto* first_if_body = first_if->body;
    const auto* block_in_first_if = first_if_body->statements[0]->As<ast::BlockStatement>();
    ASSERT_TRUE(block_in_first_if != nullptr);
    const auto* discard_in_first_if = block_in_first_if->statements[0]->As<ast::DiscardStatement>();
    ASSERT_TRUE(discard_in_first_if != nullptr);

    // Second statement of the outer loop is the inner `for`.
    const auto* inner_for_loop = outer_loop_body->statements[1]->As<ast::ForLoopStatement>();
    ASSERT_TRUE(inner_for_loop != nullptr);
    const auto* inner_for_loop_body = inner_for_loop->body;

    // First statement of the inner `for`: the `loop` with `if (i > j) { discard; }`.
    const auto* innermost_loop = inner_for_loop_body->statements[0]->As<ast::LoopStatement>();
    ASSERT_TRUE(innermost_loop != nullptr);
    const auto* innermost_loop_body = innermost_loop->body;
    const auto* innermost_loop_if = innermost_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(innermost_loop_if != nullptr);
    const auto* innermost_loop_if_body = innermost_loop_if->body;
    const auto* discard_in_innermost_loop =
        innermost_loop_if_body->statements[0]->As<ast::DiscardStatement>();
    ASSERT_TRUE(discard_in_innermost_loop != nullptr);

    // Second statement of the inner `for`: the `switch`; its second clause is
    // the `default` case holding the third `discard`.
    const auto* switch_statement = inner_for_loop_body->statements[1]->As<ast::SwitchStatement>();
    ASSERT_TRUE(switch_statement != nullptr);
    const auto* default_statement = switch_statement->body[1];
    const auto* default_statement_body = default_statement->body;
    const auto* discard_in_default_statement =
        default_statement_body->statements[0]->As<ast::DiscardStatement>();
    ASSERT_TRUE(discard_in_default_statement != nullptr);

    // The complete set of statements for which the tracker must answer true.
    std::unordered_set<const ast::Statement*> containing_discard = {
        function_body,          outer_loop,
        outer_loop_body,        first_if,
        first_if_body,          block_in_first_if,
        discard_in_first_if,    inner_for_loop,
        inner_for_loop_body,    innermost_loop,
        innermost_loop_body,    innermost_loop_if,
        innermost_loop_if_body, discard_in_innermost_loop,
        switch_statement,       default_statement,
        default_statement_body, discard_in_default_statement};

    // Check every statement node in the program against the expected set.
    for (auto* node : program.ASTNodes().Objects()) {
        auto* stmt = node->As<ast::Statement>();
        if (stmt == nullptr) {
            // Not a statement; the query does not apply.
            continue;
        }
        if (containing_discard.count(stmt) > 0) {
            ASSERT_TRUE(jump_tracker.ContainsIntraproceduralDiscard(*stmt));
        } else {
            ASSERT_FALSE(jump_tracker.ContainsIntraproceduralDiscard(*stmt));
        }
    }
}
// Verifies that JumpTracker::ContainsBreakForInnermostLoop() handles a `while`
// loop: it must return true for exactly the statements collected in
// `containing_loop_break`, and false for every other statement of the program.
TEST(JumpTrackerTest, WhileLoop) {
    std::string content = R"(
fn main() {
var x : u32;
x = 0;
while (x < 100) {
if (x > 50) {
break;
}
x = x + 1;
}
}
)";
    Source::File file("test.wgsl", content);
    auto program = reader::wgsl::Parse(&file);
    ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();

    JumpTracker jump_tracker(program);

    // Walk the AST down to the nodes of interest. Every As<>() result is
    // asserted to be non-null before it is dereferenced, so that an unexpected
    // AST shape is reported as a test failure instead of crashing the binary.
    // The `while` is the third statement of the function body.
    const auto* while_loop =
        program.AST().Functions()[0]->body->statements[2]->As<ast::WhileStatement>();
    ASSERT_TRUE(while_loop != nullptr);
    const auto* while_loop_body = while_loop->body;

    // `if (x > 50) { break; }` is the first statement of the loop body.
    const auto* if_statement = while_loop_body->statements[0]->As<ast::IfStatement>();
    ASSERT_TRUE(if_statement != nullptr);
    const auto* if_statement_body = if_statement->body;
    const auto* break_in_if = if_statement_body->statements[0]->As<ast::BreakStatement>();
    ASSERT_TRUE(break_in_if != nullptr);

    // The complete set of statements for which the tracker must answer true.
    std::unordered_set<const ast::Statement*> containing_loop_break = {
        while_loop_body, if_statement, if_statement_body, break_in_if};

    // Check every statement node in the program against the expected set.
    for (auto* node : program.ASTNodes().Objects()) {
        auto* stmt = node->As<ast::Statement>();
        if (stmt == nullptr) {
            // Not a statement; the query does not apply.
            continue;
        }
        if (containing_loop_break.count(stmt) > 0) {
            ASSERT_TRUE(jump_tracker.ContainsBreakForInnermostLoop(*stmt));
        } else {
            ASSERT_FALSE(jump_tracker.ContainsBreakForInnermostLoop(*stmt));
        }
    }
}
} // namespace
} // namespace tint::fuzzers::ast_fuzzer