Import Tint changes from Dawn
Changes:
- 4b88dbcf8e14aefafe858c931acb36f3ca2c622e Fixup continue support in while loops. by dan sinclair <dsinclair@chromium.org>
- 57dcd3601ccb9392952f892711b42d0068903b7f Fixup merge test issue by dan sinclair <dsinclair@chromium.org>
- 49d1a2d9502c35ac9a507ffa7d6b7130eb96af07 Add while statement parsing. by dan sinclair <dsinclair@chromium.org>
- 2a8861d20ec425d7391631331089b41bd4b47933 tint: Rework errors around variable declarations by Ben Clayton <bclayton@google.com>
- 418e873ad28c0fed6722ee617ba56018c24d6926 tint: Make sure enable directives go first in ordered_glo... by Zhaoming Jiang <zhaoming.jiang@intel.com>
- 33563dc7d7b3cbec159e004f96b59fd479084635 tint/transform: Move State to anonymous namespace by James Price <jrprice@google.com>
GitOrigin-RevId: 4b88dbcf8e14aefafe858c931acb36f3ca2c622e
Change-Id: Ic4b2bbdd348ad65ef27a6e80954decf0eff2a591
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/94080
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index e7171a9..648cf17 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -336,6 +336,8 @@
"ast/vector.h",
"ast/void.cc",
"ast/void.h",
+ "ast/while_statement.cc",
+ "ast/while_statement.h",
"ast/workgroup_attribute.cc",
"ast/workgroup_attribute.h",
"castable.cc",
@@ -436,6 +438,7 @@
"sem/u32.h",
"sem/vector.h",
"sem/void.h",
+ "sem/while_statement.h",
"source.cc",
"source.h",
"symbol.cc",
@@ -523,6 +526,8 @@
"transform/vectorize_scalar_matrix_constructors.h",
"transform/vertex_pulling.cc",
"transform/vertex_pulling.h",
+ "transform/while_to_loop.cc",
+ "transform/while_to_loop.h",
"transform/wrap_arrays_in_structs.cc",
"transform/wrap_arrays_in_structs.h",
"transform/zero_init_workgroup_memory.cc",
@@ -666,6 +671,8 @@
"sem/vector.h",
"sem/void.cc",
"sem/void.h",
+ "sem/while_statement.cc",
+ "sem/while_statement.h",
]
public_deps = [ ":libtint_core_all_src" ]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 6ffc36a..a3767b5 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -223,6 +223,8 @@
ast/vector.h
ast/void.cc
ast/void.h
+ ast/while_statement.cc
+ ast/while_statement.h
ast/workgroup_attribute.cc
ast/workgroup_attribute.h
castable.cc
@@ -365,6 +367,8 @@
sem/vector.h
sem/void.cc
sem/void.h
+ sem/while_statement.cc
+ sem/while_statement.h
symbol_table.cc
symbol_table.h
symbol.cc
@@ -450,6 +454,8 @@
transform/vectorize_scalar_matrix_constructors.h
transform/vertex_pulling.cc
transform/vertex_pulling.h
+ transform/while_to_loop.cc
+ transform/while_to_loop.h
transform/wrap_arrays_in_structs.cc
transform/wrap_arrays_in_structs.h
transform/zero_init_workgroup_memory.cc
@@ -743,6 +749,7 @@
ast/variable_decl_statement_test.cc
ast/variable_test.cc
ast/vector_test.cc
+ ast/while_statement_test.cc
ast/workgroup_attribute_test.cc
castable_test.cc
clone_context_test.cc
@@ -987,6 +994,7 @@
reader/wgsl/parser_impl_variable_ident_decl_test.cc
reader/wgsl/parser_impl_variable_stmt_test.cc
reader/wgsl/parser_impl_variable_qualifier_test.cc
+ reader/wgsl/parser_impl_while_stmt_test.cc
reader/wgsl/token_test.cc
)
endif()
@@ -1102,6 +1110,7 @@
transform/var_for_dynamic_index_test.cc
transform/vectorize_scalar_matrix_constructors_test.cc
transform/vertex_pulling_test.cc
+ transform/while_to_loop_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc
transform/utils/get_insertion_point_test.cc
diff --git a/src/tint/ast/while_statement.cc b/src/tint/ast/while_statement.cc
new file mode 100644
index 0000000..3666baf
--- /dev/null
+++ b/src/tint/ast/while_statement.cc
@@ -0,0 +1,48 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/while_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::WhileStatement);
+
+namespace tint::ast {
+
+WhileStatement::WhileStatement(ProgramID pid,
+ const Source& src,
+ const Expression* cond,
+ const BlockStatement* b)
+ : Base(pid, src), condition(cond), body(b) {
+ TINT_ASSERT(AST, cond);
+ TINT_ASSERT(AST, body);
+
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
+}
+
+WhileStatement::WhileStatement(WhileStatement&&) = default;
+
+WhileStatement::~WhileStatement() = default;
+
+const WhileStatement* WhileStatement::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+
+ auto* cond = ctx->Clone(condition);
+ auto* b = ctx->Clone(body);
+ return ctx->dst->create<WhileStatement>(src, cond, b);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/ast/while_statement.h b/src/tint/ast/while_statement.h
new file mode 100644
index 0000000..9a7a6b0
--- /dev/null
+++ b/src/tint/ast/while_statement.h
@@ -0,0 +1,55 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_AST_WHILE_STATEMENT_H_
+#define SRC_TINT_AST_WHILE_STATEMENT_H_
+
+#include "src/tint/ast/block_statement.h"
+
+namespace tint::ast {
+
+class Expression;
+
+/// A while loop statement
+class WhileStatement final : public Castable<WhileStatement, Statement> {
+ public:
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param source the for loop statement source
+ /// @param condition the optional loop condition expression
+ /// @param body the loop body
+ WhileStatement(ProgramID program_id,
+ Source const& source,
+ const Expression* condition,
+ const BlockStatement* body);
+ /// Move constructor
+ WhileStatement(WhileStatement&&);
+ ~WhileStatement() override;
+
+ /// Clones this node and all transitive child nodes using the `CloneContext`
+ /// `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const WhileStatement* Clone(CloneContext* ctx) const override;
+
+ /// The condition expression
+ const Expression* const condition;
+
+ /// The loop body block
+ const BlockStatement* const body;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_AST_WHILE_STATEMENT_H_
diff --git a/src/tint/ast/while_statement_test.cc b/src/tint/ast/while_statement_test.cc
new file mode 100644
index 0000000..73c9e56
--- /dev/null
+++ b/src/tint/ast/while_statement_test.cc
@@ -0,0 +1,85 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/ast/binary_expression.h"
+#include "src/tint/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using WhileStatementTest = TestHelper;
+
+TEST_F(WhileStatementTest, Creation) {
+ auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+ auto* body = Block(Return());
+ auto* l = While(cond, body);
+
+ EXPECT_EQ(l->condition, cond);
+ EXPECT_EQ(l->body, body);
+}
+
+TEST_F(WhileStatementTest, Creation_WithSource) {
+ auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
+ auto* body = Block(Return());
+ auto* l = While(Source{{20u, 2u}}, cond, body);
+ auto src = l->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(WhileStatementTest, Assert_Null_Cond) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ auto* body = b.Block();
+ b.While(nullptr, body);
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_Null_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr("i"), b.Expr(5_u));
+ b.While(cond, nullptr);
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.While(b2.Expr(true), b1.Block());
+ },
+ "internal compiler error");
+}
+
+TEST_F(WhileStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.While(b1.Expr(true), b2.Block());
+ },
+ "internal compiler error");
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 960f3d7..5f17d1d 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -76,6 +76,7 @@
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h"
#include "src/tint/ast/void.h"
+#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
@@ -2339,6 +2340,27 @@
return create<ast::ForLoopStatement>(init, Expr(std::forward<COND>(cond)), cont, body);
}
+ /// Creates a ast::WhileStatement with input body and condition.
+ /// @param source the source information
+ /// @param cond the loop condition
+ /// @param body the loop body
+ /// @returns the while statement pointer
+ template <typename COND>
+ const ast::WhileStatement* While(const Source& source,
+ COND&& cond,
+ const ast::BlockStatement* body) {
+ return create<ast::WhileStatement>(source, Expr(std::forward<COND>(cond)), body);
+ }
+
+ /// Creates a ast::WhileStatement with given condition and body.
+ /// @param cond the condition
+ /// @param body the loop body
+ /// @returns the while loop statement pointer
+ template <typename COND>
+ const ast::WhileStatement* While(COND&& cond, const ast::BlockStatement* body) {
+ return create<ast::WhileStatement>(Expr(std::forward<COND>(cond)), body);
+ }
+
/// Creates a ast::VariableDeclStatement for the input variable
/// @param source the source information
/// @param var the variable to wrap in a decl statement
diff --git a/src/tint/reader/wgsl/lexer.cc b/src/tint/reader/wgsl/lexer.cc
index 7943894..e2b016b 100644
--- a/src/tint/reader/wgsl/lexer.cc
+++ b/src/tint/reader/wgsl/lexer.cc
@@ -1271,6 +1271,9 @@
if (str == "vec4") {
return {Token::Type::kVec4, source, "vec4"};
}
+ if (str == "while") {
+ return {Token::Type::kWhile, source, "while"};
+ }
if (str == "workgroup") {
return {Token::Type::kWorkgroup, source, "workgroup"};
}
diff --git a/src/tint/reader/wgsl/lexer_test.cc b/src/tint/reader/wgsl/lexer_test.cc
index d8decba..ab4ec7e 100644
--- a/src/tint/reader/wgsl/lexer_test.cc
+++ b/src/tint/reader/wgsl/lexer_test.cc
@@ -990,6 +990,7 @@
TokenData{"vec2", Token::Type::kVec2},
TokenData{"vec3", Token::Type::kVec3},
TokenData{"vec4", Token::Type::kVec4},
+ TokenData{"while", Token::Type::kWhile},
TokenData{"workgroup", Token::Type::kWorkgroup}));
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 1319478..0c853bd 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -433,7 +433,7 @@
}
if (gc.matched) {
- if (!expect("let declaration", Token::Type::kSemicolon)) {
+ if (!expect("'let' declaration", Token::Type::kSemicolon)) {
return Failure::kErrored;
}
@@ -562,7 +562,7 @@
bool is_overridable = false;
const char* use = nullptr;
if (match(Token::Type::kLet)) {
- use = "let declaration";
+ use = "'let' declaration";
} else if (match(Token::Type::kOverride)) {
use = "override declaration";
is_overridable = true;
@@ -575,8 +575,18 @@
return Failure::kErrored;
}
+ bool has_initializer = false;
+ if (is_overridable) {
+ has_initializer = match(Token::Type::kEqual);
+ } else {
+ if (!expect(use, Token::Type::kEqual)) {
+ return Failure::kErrored;
+ }
+ has_initializer = true;
+ }
+
const ast::Expression* initializer = nullptr;
- if (match(Token::Type::kEqual)) {
+ if (has_initializer) {
auto init = expect_const_expr();
if (init.errored) {
return Failure::kErrored;
@@ -1587,6 +1597,7 @@
// | switch_stmt
// | loop_stmt
// | for_stmt
+// | while_stmt
// | non_block_statement
// : return_stmt SEMICOLON
// | func_call_stmt SEMICOLON
@@ -1644,6 +1655,14 @@
return stmt_for.value;
}
+ auto stmt_while = while_stmt();
+ if (stmt_while.errored) {
+ return Failure::kErrored;
+ }
+ if (stmt_while.matched) {
+ return stmt_while.value;
+ }
+
if (peek_is(Token::Type::kBraceLeft)) {
auto body = expect_body_stmt();
if (body.errored) {
@@ -1757,13 +1776,13 @@
// | CONST variable_ident_decl EQUAL logical_or_expression
Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
if (match(Token::Type::kLet)) {
- auto decl = expect_variable_ident_decl("let declaration",
+ auto decl = expect_variable_ident_decl("'let' declaration",
/*allow_inferred = */ true);
if (decl.errored) {
return Failure::kErrored;
}
- if (!expect("let declaration", Token::Type::kEqual)) {
+ if (!expect("'let' declaration", Token::Type::kEqual)) {
return Failure::kErrored;
}
@@ -1772,7 +1791,7 @@
return Failure::kErrored;
}
if (!constructor.matched) {
- return add_error(peek(), "missing constructor for let declaration");
+ return add_error(peek(), "missing constructor for 'let' declaration");
}
auto* var = create<ast::Variable>(decl->source, // source
@@ -1803,7 +1822,7 @@
return Failure::kErrored;
}
if (!constructor_expr.matched) {
- return add_error(peek(), "missing constructor for variable declaration");
+ return add_error(peek(), "missing constructor for 'var' declaration");
}
constructor = constructor_expr.value;
@@ -2181,6 +2200,30 @@
create<ast::BlockStatement>(stmts.value));
}
+// while_statement
+// : WHILE expression compound_statement
+Maybe<const ast::WhileStatement*> ParserImpl::while_stmt() {
+ Source source;
+ if (!match(Token::Type::kWhile, &source)) {
+ return Failure::kNoMatch;
+ }
+
+ auto condition = logical_or_expression();
+ if (condition.errored) {
+ return Failure::kErrored;
+ }
+ if (!condition.matched) {
+ return add_error(peek(), "unable to parse while condition expression");
+ }
+
+ auto body = expect_body_stmt();
+ if (body.errored) {
+ return Failure::kErrored;
+ }
+
+ return create<ast::WhileStatement>(source, condition.value, body.value);
+}
+
// func_call_stmt
// : IDENT argument_expression_list
Maybe<const ast::CallStatement*> ParserImpl::func_call_stmt() {
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 0c95159..3f9b090 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -527,6 +527,9 @@
/// Parses a `for_stmt` grammar element
/// @returns the parsed for loop or nullptr
Maybe<const ast::ForLoopStatement*> for_stmt();
+ /// Parses a `while_stmt` grammar element
+ /// @returns the parsed while loop or nullptr
+ Maybe<const ast::WhileStatement*> while_stmt();
/// Parses a `continuing_stmt` grammar element
/// @returns the parsed statements
Maybe<const ast::BlockStatement*> continuing_stmt();
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index 972e118..fc141c0 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -204,7 +204,7 @@
TEST_F(ParserImplErrorTest, ConstVarStmtInvalid) {
EXPECT("fn f() { let >; }",
- R"(test.wgsl:1:14 error: expected identifier for let declaration
+ R"(test.wgsl:1:14 error: expected identifier for 'let' declaration
fn f() { let >; }
^
)");
@@ -212,7 +212,7 @@
TEST_F(ParserImplErrorTest, ConstVarStmtMissingAssignment) {
EXPECT("fn f() { let a : i32; }",
- R"(test.wgsl:1:21 error: expected '=' for let declaration
+ R"(test.wgsl:1:21 error: expected '=' for 'let' declaration
fn f() { let a : i32; }
^
)");
@@ -220,7 +220,7 @@
TEST_F(ParserImplErrorTest, ConstVarStmtMissingConstructor) {
EXPECT("fn f() { let a : i32 = >; }",
- R"(test.wgsl:1:24 error: missing constructor for let declaration
+ R"(test.wgsl:1:24 error: missing constructor for 'let' declaration
fn f() { let a : i32 = >; }
^
)");
@@ -472,7 +472,7 @@
TEST_F(ParserImplErrorTest, GlobalDeclConstInvalidIdentifier) {
EXPECT("let ^ : i32 = 1;",
- R"(test.wgsl:1:5 error: expected identifier for let declaration
+ R"(test.wgsl:1:5 error: expected identifier for 'let' declaration
let ^ : i32 = 1;
^
)");
@@ -480,7 +480,7 @@
TEST_F(ParserImplErrorTest, GlobalDeclConstMissingSemicolon) {
EXPECT("let i : i32 = 1",
- R"(test.wgsl:1:16 error: expected ';' for let declaration
+ R"(test.wgsl:1:16 error: expected ';' for 'let' declaration
let i : i32 = 1
^
)");
@@ -512,7 +512,7 @@
TEST_F(ParserImplErrorTest, GlobalDeclConstBadConstLiteralSpaceLessThan) {
EXPECT("let i = 1 < 2;",
- R"(test.wgsl:1:11 error: expected ';' for let declaration
+ R"(test.wgsl:1:11 error: expected ';' for 'let' declaration
let i = 1 < 2;
^
)");
@@ -1215,7 +1215,7 @@
TEST_F(ParserImplErrorTest, VarStmtInvalidAssignment) {
EXPECT("fn f() { var a : u32 = >; }",
- R"(test.wgsl:1:24 error: missing constructor for variable declaration
+ R"(test.wgsl:1:24 error: missing constructor for 'var' declaration
fn f() { var a : u32 = >; }
^
)");
diff --git a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
index 26f3298..3ae7a32 100644
--- a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc
@@ -253,7 +253,7 @@
// Test a for loop with an invalid initializer statement.
TEST_F(ForStmtErrorTest, InvalidInitializerAsConstDecl) {
std::string for_str = "for (let x: i32;;) { }";
- std::string error_str = "1:16: expected '=' for let declaration";
+ std::string error_str = "1:16: expected '=' for 'let' declaration";
TestForWithError(for_str, error_str);
}
@@ -304,7 +304,7 @@
// Test a for loop with an invalid body.
TEST_F(ForStmtErrorTest, InvalidBody) {
std::string for_str = "for (;;) { let x: i32; }";
- std::string error_str = "1:22: expected '=' for let declaration";
+ std::string error_str = "1:22: expected '=' for 'let' declaration";
TestForWithError(for_str, error_str);
}
diff --git a/src/tint/reader/wgsl/parser_impl_global_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
index e2c7d7a..59012a3 100644
--- a/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
@@ -61,18 +61,25 @@
EXPECT_EQ(v->symbol, program.Symbols().Get("a"));
}
+TEST_F(ParserImplTest, GlobalDecl_GlobalConstant_MissingInitializer) {
+ auto p = parser("let a : vec2<i32>;");
+ p->global_decl();
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:18: expected '=' for 'let' declaration");
+}
+
TEST_F(ParserImplTest, GlobalDecl_GlobalConstant_Invalid) {
auto p = parser("let a : vec2<i32> 1.0;");
p->global_decl();
ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:19: expected ';' for let declaration");
+ EXPECT_EQ(p->error(), "1:19: expected '=' for 'let' declaration");
}
TEST_F(ParserImplTest, GlobalDecl_GlobalConstant_MissingSemicolon) {
auto p = parser("let a : vec2<i32> = vec2<i32>(1, 2)");
p->global_decl();
ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:36: expected ';' for let declaration");
+ EXPECT_EQ(p->error(), "1:36: expected ';' for 'let' declaration");
}
TEST_F(ParserImplTest, GlobalDecl_TypeAlias) {
diff --git a/src/tint/reader/wgsl/parser_impl_reserved_keyword_test.cc b/src/tint/reader/wgsl/parser_impl_reserved_keyword_test.cc
index e840af7..8da2771 100644
--- a/src/tint/reader/wgsl/parser_impl_reserved_keyword_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_reserved_keyword_test.cc
@@ -103,8 +103,7 @@
"unless",
"using",
"vec",
- "void",
- "while"));
+ "void"));
} // namespace
} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_statement_test.cc b/src/tint/reader/wgsl/parser_impl_statement_test.cc
index 235c748..8190829 100644
--- a/src/tint/reader/wgsl/parser_impl_statement_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_statement_test.cc
@@ -114,7 +114,7 @@
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:14: missing constructor for variable declaration");
+ EXPECT_EQ(p->error(), "1:14: missing constructor for 'var' declaration");
}
TEST_F(ParserImplTest, Statement_Variable_MissingSemicolon) {
diff --git a/src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc
index ee89947..9fdc503 100644
--- a/src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc
@@ -63,7 +63,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:15: missing constructor for variable declaration");
+ EXPECT_EQ(p->error(), "1:15: missing constructor for 'var' declaration");
}
TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit) {
@@ -160,7 +160,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:13: expected '=' for let declaration");
+ EXPECT_EQ(p->error(), "1:13: expected '=' for 'let' declaration");
}
TEST_F(ParserImplTest, VariableStmt_Let_MissingConstructor) {
@@ -170,7 +170,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:14: missing constructor for let declaration");
+ EXPECT_EQ(p->error(), "1:14: missing constructor for 'let' declaration");
}
TEST_F(ParserImplTest, VariableStmt_Let_InvalidConstructor) {
@@ -180,7 +180,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:15: missing constructor for let declaration");
+ EXPECT_EQ(p->error(), "1:15: missing constructor for 'let' declaration");
}
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl_while_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_while_stmt_test.cc
new file mode 100644
index 0000000..45c4990
--- /dev/null
+++ b/src/tint/reader/wgsl/parser_impl_while_stmt_test.cc
@@ -0,0 +1,157 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
+
+#include "src/tint/ast/discard_statement.h"
+
+namespace tint::reader::wgsl {
+namespace {
+
+using WhileStmtTest = ParserImplTest;
+
+// Test an empty while loop.
+TEST_F(WhileStmtTest, Empty) {
+ auto p = parser("while (true) { }");
+ auto wl = p->while_stmt();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_FALSE(wl.errored);
+ ASSERT_TRUE(wl.matched);
+ EXPECT_TRUE(Is<ast::Expression>(wl->condition));
+ EXPECT_TRUE(wl->body->Empty());
+}
+
+// Test a while loop with non-empty body.
+TEST_F(WhileStmtTest, Body) {
+ auto p = parser("while (true) { discard; }");
+ auto wl = p->while_stmt();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_FALSE(wl.errored);
+ ASSERT_TRUE(wl.matched);
+ EXPECT_TRUE(Is<ast::Expression>(wl->condition));
+ ASSERT_EQ(wl->body->statements.size(), 1u);
+ EXPECT_TRUE(wl->body->statements[0]->Is<ast::DiscardStatement>());
+}
+
+// Test a while loop with complex condition.
+TEST_F(WhileStmtTest, ComplexCondition) {
+ auto p = parser("while ((a + 1 - 2) == 3) { }");
+ auto wl = p->while_stmt();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_FALSE(wl.errored);
+ ASSERT_TRUE(wl.matched);
+ EXPECT_TRUE(Is<ast::Expression>(wl->condition));
+ EXPECT_TRUE(wl->body->Empty());
+}
+
+// Test a while loop with no brackets.
+TEST_F(WhileStmtTest, NoBrackets) {
+ auto p = parser("while (a + 1 - 2) == 3 { }");
+ auto wl = p->while_stmt();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_FALSE(wl.errored);
+ ASSERT_TRUE(wl.matched);
+ EXPECT_TRUE(Is<ast::BinaryExpression>(wl->condition));
+ EXPECT_TRUE(wl->body->Empty());
+}
+
+class WhileStmtErrorTest : public ParserImplTest {
+ public:
+ void TestForWithError(std::string for_str, std::string error_str) {
+ auto p_for = parser(for_str);
+ auto e_for = p_for->while_stmt();
+
+ EXPECT_FALSE(e_for.matched);
+ EXPECT_TRUE(e_for.errored);
+ EXPECT_TRUE(p_for->has_error());
+ ASSERT_EQ(e_for.value, nullptr);
+ EXPECT_EQ(p_for->error(), error_str);
+ }
+};
+
+// Test a while loop with missing left parenthesis is invalid.
+TEST_F(WhileStmtErrorTest, MissingLeftParen) {
+ std::string while_str = "while bool) { }";
+ std::string error_str = "1:11: expected '(' for type constructor";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with missing condition is invalid.
+TEST_F(WhileStmtErrorTest, MissingFirstSemicolon) {
+ std::string while_str = "while () {}";
+ std::string error_str = "1:8: unable to parse expression";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with missing right parenthesis is invalid.
+TEST_F(WhileStmtErrorTest, MissingRightParen) {
+ std::string while_str = "while (true {}";
+ std::string error_str = "1:13: expected ')'";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with missing left brace is invalid.
+TEST_F(WhileStmtErrorTest, MissingLeftBrace) {
+ std::string while_str = "while (true) }";
+ std::string error_str = "1:14: expected '{'";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a for loop with missing right brace is invalid.
+TEST_F(WhileStmtErrorTest, MissingRightBrace) {
+ std::string while_str = "while (true) {";
+ std::string error_str = "1:15: expected '}'";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with an invalid break condition.
+TEST_F(WhileStmtErrorTest, InvalidBreakConditionAsExpression) {
+ std::string while_str = "while ((0 == 1) { }";
+ std::string error_str = "1:17: expected ')'";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with a break condition not matching
+// logical_or_expression.
+TEST_F(WhileStmtErrorTest, InvalidBreakConditionMatch) {
+ std::string while_str = "while (var i: i32 = 0) { }";
+ std::string error_str = "1:8: unable to parse expression";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a while loop with an invalid body.
+TEST_F(WhileStmtErrorTest, InvalidBody) {
+ std::string while_str = "while (true) { let x: i32; }";
+ std::string error_str = "1:26: expected '=' for 'let' declaration";
+
+ TestForWithError(while_str, error_str);
+}
+
+// Test a for loop with a body not matching statements
+TEST_F(WhileStmtErrorTest, InvalidBodyMatch) {
+ std::string while_str = "while (true) { fn main() {} }";
+ std::string error_str = "1:16: expected '}'";
+
+ TestForWithError(while_str, error_str);
+}
+
+} // namespace
+} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/token.cc b/src/tint/reader/wgsl/token.cc
index dcd72cf..aec1c9a 100644
--- a/src/tint/reader/wgsl/token.cc
+++ b/src/tint/reader/wgsl/token.cc
@@ -263,6 +263,8 @@
return "vec3";
case Token::Type::kVec4:
return "vec4";
+ case Token::Type::kWhile:
+ return "while";
case Token::Type::kWorkgroup:
return "workgroup";
}
diff --git a/src/tint/reader/wgsl/token.h b/src/tint/reader/wgsl/token.h
index 9587f36..30c2edf 100644
--- a/src/tint/reader/wgsl/token.h
+++ b/src/tint/reader/wgsl/token.h
@@ -274,6 +274,8 @@
kVec3,
/// A 'vec4'
kVec4,
+ /// A 'while'
+ kWhile,
/// A 'workgroup'
kWorkgroup,
};
diff --git a/src/tint/resolver/assignment_validation_test.cc b/src/tint/resolver/assignment_validation_test.cc
index 64363e3..6af636e 100644
--- a/src/tint/resolver/assignment_validation_test.cc
+++ b/src/tint/resolver/assignment_validation_test.cc
@@ -222,7 +222,7 @@
WrapInFunction(var, Assign(Expr(Source{{12, 34}}, "a"), 2_i));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot assign to const\nnote: 'a' is declared here:");
+ EXPECT_EQ(r()->error(), "12:34 error: cannot assign to 'let'\nnote: 'a' is declared here:");
}
TEST_F(ResolverAssignmentValidationTest, AssignNonConstructible_Handle) {
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 29ad152..f7e401b 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -721,7 +721,7 @@
} else {
EXPECT_FALSE(r()->Resolve());
if (!IsBindingAttribute(params.kind)) {
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for variables");
+ EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for module-scope 'var'");
}
}
}
@@ -783,7 +783,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for constants");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: attribute is not valid for module-scope 'let' declaration");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -792,6 +793,53 @@
TestParams{AttributeKind::kBinding, false},
TestParams{AttributeKind::kBuiltin, false},
TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kInterpolate, false},
+ TestParams{AttributeKind::kInvariant, false},
+ TestParams{AttributeKind::kLocation, false},
+ TestParams{AttributeKind::kOffset, false},
+ TestParams{AttributeKind::kSize, false},
+ TestParams{AttributeKind::kStage, false},
+ TestParams{AttributeKind::kStride, false},
+ TestParams{AttributeKind::kWorkgroup, false},
+ TestParams{AttributeKind::kBindingAndGroup, false}));
+
+TEST_F(ConstantAttributeTest, DuplicateAttribute) {
+ GlobalConst("a", ty.f32(), Expr(1.23_f),
+ ast::AttributeList{
+ create<ast::IdAttribute>(Source{{12, 34}}, 0),
+ create<ast::IdAttribute>(Source{{56, 78}}, 1),
+ });
+
+ WrapInFunction();
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(56:78 error: duplicate id attribute
+12:34 note: first attribute declared here)");
+}
+
+using OverrideAttributeTest = TestWithParams;
+TEST_P(OverrideAttributeTest, IsValid) {
+ auto& params = GetParam();
+
+ Override("a", ty.f32(), Expr(1.23_f), createAttributes(Source{{12, 34}}, *this, params.kind));
+
+ WrapInFunction();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for 'override' declaration");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
+ OverrideAttributeTest,
+ testing::Values(TestParams{AttributeKind::kAlign, false},
+ TestParams{AttributeKind::kBinding, false},
+ TestParams{AttributeKind::kBuiltin, false},
+ TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, true},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
@@ -803,7 +851,7 @@
TestParams{AttributeKind::kWorkgroup, false},
TestParams{AttributeKind::kBindingAndGroup, false}));
-TEST_F(ConstantAttributeTest, DuplicateAttribute) {
+TEST_F(OverrideAttributeTest, DuplicateAttribute) {
GlobalConst("a", ty.f32(), Expr(1.23_f),
ast::AttributeList{
create<ast::IdAttribute>(Source{{12, 34}}, 0),
diff --git a/src/tint/resolver/compound_assignment_validation_test.cc b/src/tint/resolver/compound_assignment_validation_test.cc
index 1ae7040..b453c3b 100644
--- a/src/tint/resolver/compound_assignment_validation_test.cc
+++ b/src/tint/resolver/compound_assignment_validation_test.cc
@@ -249,7 +249,7 @@
WrapInFunction(a, CompoundAssign(Expr(Source{{56, 78}}, "a"), 1_i, ast::BinaryOp::kAdd));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(56:78 error: cannot assign to const
+ EXPECT_EQ(r()->error(), R"(56:78 error: cannot assign to 'let'
12:34 note: 'a' is declared here:)");
}
diff --git a/src/tint/resolver/compound_statement_test.cc b/src/tint/resolver/compound_statement_test.cc
index 0444bd3..d962b3e 100644
--- a/src/tint/resolver/compound_statement_test.cc
+++ b/src/tint/resolver/compound_statement_test.cc
@@ -21,6 +21,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/while_statement.h"
using namespace tint::number_suffixes; // NOLINT
@@ -239,6 +240,55 @@
}
}
+TEST_F(ResolverCompoundStatementTest, While) {
+ // fn F() {
+ // while (true) {
+ // return;
+ // }
+ // }
+ auto* cond = Expr(true);
+ auto* stmt = Return();
+ auto* body = Block(stmt);
+ auto* while_ = While(cond, body);
+ auto* f = Func("W", {}, ty.void_(), {while_});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ {
+ auto* s = Sem().Get(while_);
+ ASSERT_NE(s, nullptr);
+ EXPECT_EQ(Sem().Get(body)->Parent(), s);
+ EXPECT_TRUE(s->Is<sem::WhileStatement>());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Parent(), s->Block());
+ }
+ { // Condition expression's statement is the while itself
+ auto* e = Sem().Get(cond);
+ ASSERT_NE(e, nullptr);
+ auto* s = e->Stmt();
+ ASSERT_NE(s, nullptr);
+ ASSERT_TRUE(Is<sem::WhileStatement>(s));
+ ASSERT_NE(s->Parent(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_TRUE(Is<sem::FunctionBlockStatement>(s->Block()));
+ }
+ {
+ auto* s = Sem().Get(stmt);
+ ASSERT_NE(s, nullptr);
+ ASSERT_NE(s->Block(), nullptr);
+ EXPECT_EQ(s->Parent(), s->Block());
+ EXPECT_EQ(s->Block(), s->FindFirstParent<sem::LoopBlockStatement>());
+ EXPECT_TRUE(Is<sem::WhileStatement>(s->Parent()->Parent()));
+ EXPECT_EQ(s->Block()->Parent(), s->FindFirstParent<sem::WhileStatement>());
+ ASSERT_TRUE(Is<sem::FunctionBlockStatement>(s->Block()->Parent()->Parent()));
+ EXPECT_EQ(s->Block()->Parent()->Parent(),
+ s->FindFirstParent<sem::FunctionBlockStatement>());
+ EXPECT_EQ(s->Function()->Declaration(), f);
+ EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr);
+ }
+}
+
TEST_F(ResolverCompoundStatementTest, If) {
// fn F() {
// if (cond_a) {
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 7e66899..cf61149 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -263,6 +263,12 @@
TraverseExpression(v->variable->constructor);
Declare(v->variable->symbol, v->variable);
},
+ [&](const ast::WhileStatement* w) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseExpression(w->condition);
+ TraverseStatement(w->body);
+ },
[&](Default) {
if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) {
@@ -569,8 +575,18 @@
return; // This code assumes there are no undeclared identifiers.
}
- std::unordered_set<const Global*> visited;
+ // Make sure all 'enable' directives go before any other global declarations.
for (auto* global : declaration_order_) {
+ if (auto* enable = global->node->As<ast::Enable>()) {
+ sorted_.add(enable);
+ }
+ }
+
+ for (auto* global : declaration_order_) {
+ if (global->node->Is<ast::Enable>()) {
+ // Skip 'enable' directives here, as they are already added.
+ continue;
+ }
utils::UniqueVector<const Global*> stack;
TraverseDependencies(
global,
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index f562afb..1117a4a 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1080,6 +1080,20 @@
ResolverDependencyGraphOrderedGlobalsTest,
testing::Combine(testing::ValuesIn(kFuncDeclKinds),
testing::ValuesIn(kFuncUseKinds)));
+
+TEST_F(ResolverDependencyGraphOrderedGlobalsTest, EnableFirst) {
+ // Test that enable nodes always go before any other global declaration.
+ // Although all enable directives in a valid WGSL program must go before any other global
+ // declaration, a transform may produce such a AST tree that has some declarations before enable
+ // nodes. DependencyGraph should deal with these cases.
+ auto* var_1 = Global("SYMBOL1", ty.i32(), nullptr);
+ auto* enable_1 = Enable(ast::Extension::kF16);
+ auto* var_2 = Global("SYMBOL2", ty.f32(), nullptr);
+ auto* enable_2 = Enable(ast::Extension::kF16);
+
+ EXPECT_THAT(AST().GlobalDeclarations(), ElementsAre(var_1, enable_1, var_2, enable_2));
+ EXPECT_THAT(Build().ordered_globals, ElementsAre(enable_1, enable_2, var_1, var_2));
+}
} // namespace ordered_globals
////////////////////////////////////////////////////////////////////////////////
@@ -1231,6 +1245,9 @@
Assign(V, V), //
Block( //
Assign(V, V))), //
+ While(Equal(V, V), //
+ Block( //
+ Assign(V, V))), //
Loop(Block(Assign(V, V)), //
Block(Assign(V, V))), //
Switch(V, //
diff --git a/src/tint/resolver/increment_decrement_validation_test.cc b/src/tint/resolver/increment_decrement_validation_test.cc
index e03352e..b8e3aa1 100644
--- a/src/tint/resolver/increment_decrement_validation_test.cc
+++ b/src/tint/resolver/increment_decrement_validation_test.cc
@@ -151,7 +151,7 @@
WrapInFunction(a, Increment(Expr(Source{{56, 78}}, "a")));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify constant value
+ EXPECT_EQ(r()->error(), R"(56:78 error: cannot modify 'let'
12:34 note: 'a' is declared here:)");
}
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
index d8cc649..4d01f7e 100644
--- a/src/tint/resolver/inferred_type_test.cc
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -98,7 +98,7 @@
WrapInFunction();
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: global var declaration must specify a type");
+ EXPECT_EQ(r()->error(), "12:34 error: module-scope 'var' declaration must specify a type");
}
TEST_P(ResolverInferredTypeParamTest, LocalLet_Pass) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index b1f5a95..3d58930 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -49,6 +49,7 @@
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h"
+#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/resolver/uniformity.h"
#include "src/tint/sem/abstract_float.h"
@@ -77,6 +78,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/reverse.h"
@@ -326,19 +328,19 @@
// If the variable has no declared type, infer it from the RHS
if (!storage_ty) {
if (!var->is_const && kind == VariableKind::kGlobal) {
- AddError("global var declaration must specify a type", var->source);
+ AddError("module-scope 'var' declaration must specify a type", var->source);
return nullptr;
}
storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
}
} else if (var->is_const && !var->is_overridable && kind != VariableKind::kParameter) {
- AddError("let declaration must have an initializer", var->source);
+ AddError("'let' declaration must have an initializer", var->source);
return nullptr;
} else if (!var->type) {
AddError((kind == VariableKind::kGlobal)
- ? "module scope var declaration requires a type and initializer"
- : "function scope var declaration requires a type or initializer",
+ ? "module-scope 'var' declaration requires a type or initializer"
+ : "function-scope 'var' declaration requires a type or initializer",
var->source);
return nullptr;
}
@@ -368,7 +370,7 @@
storage_class != ast::StorageClass::kFunction &&
validator_.IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) {
- AddError("function variable has a non-function storage class", var->source);
+ AddError("function-scope 'var' declaration must use 'function' storage class", var->source);
return nullptr;
}
@@ -519,11 +521,13 @@
auto storage_class = sem->StorageClass();
if (!var->is_const && storage_class == ast::StorageClass::kNone) {
- AddError("global variables must have a storage class", var->source);
+ AddError("module-scope 'var' declaration must have a storage class", var->source);
return nullptr;
}
if (var->is_const && storage_class != ast::StorageClass::kNone) {
- AddError("global constants shouldn't have a storage class", var->source);
+ AddError(var->is_overridable ? "'override' declaration must not have a storage class"
+ : "'let' declaration must not have a storage class",
+ var->source);
return nullptr;
}
@@ -852,6 +856,7 @@
[&](const ast::BlockStatement* b) { return BlockStatement(b); },
[&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
[&](const ast::LoopStatement* l) { return LoopStatement(l); },
+ [&](const ast::WhileStatement* w) { return WhileStatement(w); },
[&](const ast::IfStatement* i) { return IfStatement(i); },
[&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
@@ -1037,6 +1042,39 @@
});
}
+sem::WhileStatement* Resolver::WhileStatement(const ast::WhileStatement* stmt) {
+ auto* sem =
+ builder_->create<sem::WhileStatement>(stmt, current_compound_statement_, current_function_);
+ return StatementScope(stmt, sem, [&] {
+ auto& behaviors = sem->Behaviors();
+
+ auto* cond = Expression(stmt->condition);
+ if (!cond) {
+ return false;
+ }
+ sem->SetCondition(cond);
+ behaviors.Add(cond->Behaviors());
+
+ Mark(stmt->body);
+
+ auto* body = builder_->create<sem::LoopBlockStatement>(
+ stmt->body, current_compound_statement_, current_function_);
+ if (!StatementScope(stmt->body, body, [&] { return Statements(stmt->body->statements); })) {
+ return false;
+ }
+
+ behaviors.Add(body->Behaviors());
+ // Always consider the while as having a 'next' behaviour because it has
+ // a condition. We don't check if the condition will terminate but it isn't
+ // valid to have an infinite loop in a WGSL program, so a non-terminating
+ // condition is already an invalid program.
+ behaviors.Add(sem::Behavior::kNext);
+ behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
+
+ return validator_.WhileStatement(sem);
+ });
+}
+
sem::Expression* Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
constexpr size_t kMaxExpressionDepth = 512U;
@@ -2069,21 +2107,16 @@
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
auto* var = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
- if (!var || !var->Declaration()->is_const) {
- AddError("array size identifier must be a module-scope constant", size_source);
- return nullptr;
- }
- if (var->IsOverridable()) {
- AddError("array size expression must not be pipeline-overridable", size_source);
+ if (!var || !var->Declaration()->is_const || var->IsOverridable()) {
+ AddError("array size identifier must be a literal or a module-scope 'let'",
+ size_source);
return nullptr;
}
count_expr = var->Declaration()->constructor;
} else if (!count_expr->Is<ast::LiteralExpression>()) {
- AddError(
- "array size expression must be either a literal or a module-scope "
- "constant",
- size_source);
+ AddError("array size identifier must be a literal or a module-scope 'let'",
+ size_source);
return nullptr;
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 9999651..e619322 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -54,6 +54,7 @@
class SwitchStatement;
class UnaryOpExpression;
class Variable;
+class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@@ -67,6 +68,7 @@
class Statement;
class SwitchStatement;
class TypeConstructor;
+class WhileStatement;
} // namespace tint::sem
namespace tint::resolver {
@@ -233,6 +235,7 @@
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
+ sem::WhileStatement* WhileStatement(const ast::WhileStatement*);
sem::GlobalVariable* GlobalVariable(const ast::Variable*);
sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*);
diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc
index 3a3b14f..b125e39 100644
--- a/src/tint/resolver/resolver_behavior_test.cc
+++ b/src/tint/resolver/resolver_behavior_test.cc
@@ -20,6 +20,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/while_statement.h"
using namespace tint::number_suffixes; // NOLINT
@@ -314,6 +315,56 @@
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
+TEST_F(ResolverBehaviorTest, StmtWhileBreak) {
+ auto* stmt = While(Expr(true), Block(Break()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
+}
+
+TEST_F(ResolverBehaviorTest, StmtWhileDiscard) {
+ auto* stmt = While(Expr(true), Block(Discard()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtWhileReturn) {
+ auto* stmt = While(Expr(true), Block(Return()));
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondTrue) {
+ auto* stmt = While(Expr(true), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
+}
+
+TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondCallFuncMayDiscard) {
+ auto* stmt = While(Equal(Call("DiscardOrNext"), 1_i), Block());
+ WrapInFunction(stmt);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(stmt);
+ EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
+}
+
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
auto* stmt = If(true, Block());
WrapInFunction(stmt);
diff --git a/src/tint/resolver/storage_class_validation_test.cc b/src/tint/resolver/storage_class_validation_test.cc
index 2d75167..a5e7d12 100644
--- a/src/tint/resolver/storage_class_validation_test.cc
+++ b/src/tint/resolver/storage_class_validation_test.cc
@@ -30,7 +30,8 @@
Global(Source{{12, 34}}, "g", ty.f32(), ast::StorageClass::kNone);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: global variables must have a storage class");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: module-scope 'var' declaration must have a storage class");
}
TEST_F(ResolverStorageClassValidationTest, GlobalVariableFunctionStorageClass_Fail) {
@@ -39,8 +40,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: variables declared at module scope must not be in "
- "the function storage class");
+ "12:34 error: module-scope 'var' must not use storage class 'function'");
}
TEST_F(ResolverStorageClassValidationTest, Private_RuntimeArray) {
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index 119c310..27ee04d 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -90,11 +90,21 @@
TEST_F(ResolverTypeValidationTest, GlobalLetWithStorageClass_Fail) {
// let<private> global_var: f32;
AST().AddGlobalVariable(create<ast::Variable>(
- Source{{12, 34}}, Symbols().Register("global_var"), ast::StorageClass::kPrivate,
+ Source{{12, 34}}, Symbols().Register("global_let"), ast::StorageClass::kPrivate,
ast::Access::kUndefined, ty.f32(), true, false, Expr(1.23_f), ast::AttributeList{}));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: global constants shouldn't have a storage class");
+ EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must not have a storage class");
+}
+
+TEST_F(ResolverTypeValidationTest, OverrideWithStorageClass_Fail) {
+ // let<private> global_var: f32;
+ AST().AddGlobalVariable(create<ast::Variable>(
+ Source{{12, 34}}, Symbols().Register("global_override"), ast::StorageClass::kPrivate,
+ ast::Access::kUndefined, ty.f32(), true, true, Expr(1.23_f), ast::AttributeList{}));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: 'override' declaration must not have a storage class");
}
TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) {
@@ -334,7 +344,8 @@
Override("size", nullptr, Expr(10_i));
Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: array size expression must not be pipeline-overridable");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size identifier must be a literal or a module-scope 'let'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) {
@@ -343,7 +354,8 @@
Global("size", ty.i32(), Expr(10_i), ast::StorageClass::kPrivate);
Global("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: array size identifier must be a module-scope constant");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size identifier must be a literal or a module-scope 'let'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_FunctionLet) {
@@ -355,7 +367,8 @@
auto* a = Var("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")));
WrapInFunction(Block(Decl(size), Decl(a)));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: array size identifier must be a module-scope constant");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: array size identifier must be a literal or a module-scope 'let'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_InvalidExpr) {
@@ -364,8 +377,7 @@
WrapInFunction(Block(Decl(a)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: array size expression must be either a literal or a "
- "module-scope constant");
+ "12:34 error: array size identifier must be a literal or a module-scope 'let'");
}
TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) {
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 2bbb487..18fa425 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -35,6 +35,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/block_allocator.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/unique_vector.h"
@@ -491,7 +492,7 @@
// Find the loop or switch statement that we are in.
auto* parent = sem_.Get(b)
->FindFirstParent<sem::SwitchStatement, sem::LoopStatement,
- sem::ForLoopStatement>();
+ sem::ForLoopStatement, sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
@@ -535,8 +536,9 @@
[&](const ast::ContinueStatement* c) {
// Find the loop statement that we are in.
- auto* parent =
- sem_.Get(c)->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement>();
+ auto* parent = sem_.Get(c)
+ ->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement,
+ sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
@@ -638,6 +640,68 @@
}
},
+ [&](const ast::WhileStatement* w) {
+ auto* sem_loop = sem_.Get(w);
+ auto* cfx = CreateNode("loop_start");
+
+ auto* cf_start = cf;
+
+ auto& info = current_function_->loop_switch_infos[sem_loop];
+ info.type = "whileloop";
+
+ // Create input nodes for any variables declared before this loop.
+ for (auto* v : current_function_->local_var_decls) {
+ auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
+ auto* in_node = CreateNode(name + "_value_forloop_in");
+ in_node->AddEdge(current_function_->variables.Get(v));
+ info.var_in_nodes[v] = in_node;
+ current_function_->variables.Set(v, in_node);
+ }
+
+ // Insert the condition at the start of the loop body.
+ {
+ auto [cf_cond, v] = ProcessExpression(cfx, w->condition);
+ auto* cf_condition_end = CreateNode("while_condition_CFend", w);
+ cf_condition_end->affects_control_flow = true;
+ cf_condition_end->AddEdge(v);
+ cf_start = cf_condition_end;
+ }
+
+ // Propagate assignments to the loop exit nodes.
+ for (auto* var : current_function_->local_var_decls) {
+ auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
+ auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
+ return CreateNode(name + "_value_" + info.type + "_exit");
+ });
+ exit_node->AddEdge(current_function_->variables.Get(var));
+ }
+ auto* cf1 = ProcessStatement(cf_start, w->body);
+ cfx->AddEdge(cf1);
+ cfx->AddEdge(cf);
+
+ // Add edges from variable loop input nodes to their values at the end of the loop.
+ for (auto v : info.var_in_nodes) {
+ auto* in_node = v.second;
+ auto* out_node = current_function_->variables.Get(v.first);
+ if (out_node != in_node) {
+ in_node->AddEdge(out_node);
+ }
+ }
+
+ // Set each variable's exit node as its value in the outer scope.
+ for (auto v : info.var_exit_nodes) {
+ current_function_->variables.Set(v.first, v.second);
+ }
+
+ current_function_->loop_switch_infos.erase(sem_loop);
+
+ if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
+ return cf;
+ } else {
+ return cfx;
+ }
+ },
+
[&](const ast::IfStatement* i) {
auto* sem_if = sem_.Get(i);
auto [_, v_cond] = ProcessExpression(cf, i->condition);
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index 7c07357..002f841 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -2311,6 +2311,304 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, While_CallInside_UniformCondition) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read> n : i32;
+
+fn foo() {
+ var i = 0;
+ while (i < n) {
+ workgroupBarrier();
+ i = i + 1;
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, While_CallInside_NonUniformCondition) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> n : i32;
+
+fn foo() {
+ var i = 0;
+ while (i < n) {
+ workgroupBarrier();
+ i = i + 1;
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:7:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:6:3 note: control flow depends on non-uniform value
+ while (i < n) {
+ ^^^^^
+
+test:6:14 note: reading from read_write storage buffer 'n' may result in a non-uniform value
+ while (i < n) {
+ ^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformInLoopAfterBarrier) {
+ // Use a variable for a conditional barrier in a loop, and then assign a non-uniform value to
+ // that variable later in that loop.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (v == 0) {
+ workgroupBarrier();
+ break;
+ }
+
+ v = non_uniform;
+ i++;
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:8:5 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_ConditionalAssignNonUniformWithBreak_BarrierInLoop) {
+ // In a conditional block, assign a non-uniform value and then break, then use a variable for a
+ // conditional barrier later in the loop.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (true) {
+ v = non_uniform;
+ break;
+ }
+ if (v == 0) {
+ workgroupBarrier();
+ }
+ i++;
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, While_ConditionalAssignNonUniformWithBreak_BarrierAfterLoop) {
+ // In a conditional block, assign a non-uniform value and then break, then use a variable for a
+ // conditional barrier after the loop.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (true) {
+ v = non_uniform;
+ break;
+ }
+ v = 5;
+ i++;
+ }
+
+ if (v == 0) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:17:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:16:3 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:9:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_VarRemainsNonUniformAtLoopEnd_BarrierAfterLoop) {
+ // Assign a non-uniform value, assign a uniform value before all explicit break points but leave
+ // the value non-uniform at loop exit, then use a variable for a conditional barrier after the
+ // loop.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (true) {
+ v = 5;
+ break;
+ }
+
+ v = non_uniform;
+
+ if (true) {
+ v = 6;
+ break;
+ }
+ i++;
+ }
+
+ if (v == 0) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:23:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:22:3 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformBeforeConditionalContinue_BarrierAtStart) {
+ // Use a variable for a conditional barrier in a loop, assign a non-uniform value to
+ // that variable later in that loop, then perform a conditional continue before assigning a
+ // uniform value to that variable.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (v == 0) {
+ workgroupBarrier();
+ break;
+ }
+
+ v = non_uniform;
+ if (true) {
+ continue;
+ }
+
+ v = 5;
+ i++;
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:8:5 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformBeforeConditionalContinue) {
+ // Use a variable for a conditional barrier in a loop, assign a non-uniform value to
+ // that variable later in that loop, then perform a conditional continue before assigning a
+ // uniform value to that variable.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var v = 0;
+ var i = 0;
+ while (i < 10) {
+ if (v == 0) {
+ workgroupBarrier();
+ break;
+ }
+
+ v = non_uniform;
+ if (true) {
+ continue;
+ }
+
+ v = 5;
+ i++;
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:8:5 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, While_NonUniformCondition_Reconverge) {
+ // Loops reconverge at exit, so test that we can call workgroupBarrier() after a loop that has a
+ // non-uniform condition.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> n : i32;
+
+fn foo() {
+ var i = 0;
+ while (i < n) {
+ }
+ workgroupBarrier();
+ i = i + 1;
+}
+)";
+
+ RunTest(src, true);
+}
+
} // namespace LoopTest
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index 8cc3a8f..9177d6b 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -304,7 +304,8 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: function variable has a non-function storage class");
+ EXPECT_EQ(r()->error(),
+ "error: function-scope 'var' declaration must use 'function' storage class");
}
TEST_F(ResolverValidationTest, StorageClass_FunctionVariableI32) {
@@ -317,7 +318,8 @@
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: function variable has a non-function storage class");
+ EXPECT_EQ(r()->error(),
+ "error: function-scope 'var' declaration must use 'function' storage class");
}
TEST_F(ResolverValidationTest, StorageClass_SamplerExplicitStorageClass) {
@@ -984,6 +986,26 @@
EXPECT_EQ(r()->error(), "12:34 error: for-loop condition must be bool, got f32");
}
+TEST_F(ResolverTest, Stmt_While_CondIsBoolRef) {
+ // var cond : bool = false;
+ // while (cond) {
+ // }
+
+ auto* cond = Var("cond", ty.bool_(), Expr(false));
+ WrapInFunction(Decl(cond), While("cond", Block()));
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverTest, Stmt_While_CondIsNotBool) {
+ // while (1.0f) {
+ // }
+
+ WrapInFunction(While(Expr(Source{{12, 34}}, 1_f), Block()));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: while condition must be bool, got f32");
+}
+
TEST_F(ResolverValidationTest, Stmt_ContinueInLoop) {
WrapInFunction(Loop(Block(If(false, Block(Break())), //
Continue(Source{{12, 34}}))));
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 4d1ee07..39eb31a 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -72,6 +72,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/math.h"
@@ -237,6 +238,11 @@
break;
}
}
+ if (Is<sem::WhileStatement>(s->Parent())) {
+ if (stop_at_loop) {
+ break;
+ }
+ }
}
return nullptr;
}
@@ -506,25 +512,29 @@
for (auto* attr : decl->attributes) {
if (decl->is_const) {
- if (auto* id_attr = attr->As<ast::IdAttribute>()) {
- uint32_t id = id_attr->value;
- auto it = constant_ids.find(id);
- if (it != constant_ids.end() && it->second != var) {
- AddError("pipeline constant IDs must be unique", attr->source);
- AddNote(
- "a pipeline constant with an ID of " + std::to_string(id) +
- " was previously declared "
- "here:",
- ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
- ->source);
- return false;
- }
- if (id > 65535) {
- AddError("pipeline constant IDs must be between 0 and 65535", attr->source);
+ if (decl->is_overridable) {
+ if (auto* id_attr = attr->As<ast::IdAttribute>()) {
+ uint32_t id = id_attr->value;
+ auto it = constant_ids.find(id);
+ if (it != constant_ids.end() && it->second != var) {
+ AddError("pipeline constant IDs must be unique", attr->source);
+ AddNote("a pipeline constant with an ID of " + std::to_string(id) +
+ " was previously declared here:",
+ ast::GetAttribute<ast::IdAttribute>(
+ it->second->Declaration()->attributes)
+ ->source);
+ return false;
+ }
+ if (id > 65535) {
+ AddError("pipeline constant IDs must be between 0 and 65535", attr->source);
+ return false;
+ }
+ } else {
+ AddError("attribute is not valid for 'override' declaration", attr->source);
return false;
}
} else {
- AddError("attribute is not valid for constants", attr->source);
+ AddError("attribute is not valid for module-scope 'let' declaration", attr->source);
return false;
}
} else {
@@ -536,17 +546,14 @@
if (!(attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
ast::InternalAttribute>()) &&
(!is_shader_io_attribute || !has_io_storage_class)) {
- AddError("attribute is not valid for variables", attr->source);
+ AddError("attribute is not valid for module-scope 'var'", attr->source);
return false;
}
}
}
if (var->StorageClass() == ast::StorageClass::kFunction) {
- AddError(
- "variables declared at module scope must not be in the function "
- "storage class",
- decl->source);
+ AddError("module-scope 'var' must not use storage class 'function'", decl->source);
return false;
}
@@ -559,10 +566,7 @@
// Each resource variable must be declared with both group and binding
// attributes.
if (!binding_point) {
- AddError(
- "resource variables require @group and @binding "
- "attributes",
- decl->source);
+ AddError("resource variables require @group and @binding attributes", decl->source);
return false;
}
break;
@@ -571,10 +575,8 @@
if (binding_point.binding || binding_point.group) {
// https://gpuweb.github.io/gpuweb/wgsl/#attribute-binding
// Must only be applied to a resource variable
- AddError(
- "non-resource variables must not have @group or @binding "
- "attributes",
- decl->source);
+ AddError("non-resource variables must not have @group or @binding attributes",
+ decl->source);
return false;
}
}
@@ -1464,6 +1466,22 @@
return true;
}
+bool Validator::WhileStatement(const sem::WhileStatement* stmt) const {
+ if (stmt->Behaviors().Empty()) {
+ AddError("while does not exit", stmt->Declaration()->source.Begin());
+ return false;
+ }
+ if (auto* cond = stmt->Condition()) {
+ auto* cond_ty = cond->Type()->UnwrapRef();
+ if (!cond_ty->Is<sem::Bool>()) {
+ AddError("while condition must be bool, got " + sem_.TypeNameOf(cond_ty),
+ stmt->Condition()->Declaration()->source);
+ return false;
+ }
+ }
+ return true;
+}
+
bool Validator::IfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
@@ -2162,7 +2180,9 @@
return false;
}
if (decl->is_const) {
- AddError("cannot assign to const", lhs->source);
+ AddError(
+ decl->is_overridable ? "cannot assign to 'override'" : "cannot assign to 'let'",
+ lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false;
}
@@ -2210,7 +2230,8 @@
return false;
}
if (decl->is_const) {
- AddError("cannot modify constant value", lhs->source);
+ AddError(decl->is_overridable ? "cannot modify 'override'" : "cannot modify 'let'",
+ lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false;
}
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index b30fdc7..e3912ba 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -44,6 +44,7 @@
class SwitchStatement;
class UnaryOpExpression;
class Variable;
+class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@@ -58,6 +59,7 @@
class Statement;
class SwitchStatement;
class TypeConstructor;
+class WhileStatement;
} // namespace tint::sem
namespace tint::resolver {
@@ -207,6 +209,11 @@
/// @returns true on success, false otherwise
bool ForLoopStatement(const sem::ForLoopStatement* stmt) const;
+ /// Validates a while loop
+ /// @param stmt the while statement to validate
+ /// @returns true on success, false otherwise
+ bool WhileStatement(const sem::WhileStatement* stmt) const;
+
/// Validates a fallthrough statement
/// @param stmt the fallthrough to validate
/// @returns true on success, false otherwise
diff --git a/src/tint/resolver/var_let_validation_test.cc b/src/tint/resolver/var_let_validation_test.cc
index e1dc343..eea5e72 100644
--- a/src/tint/resolver/var_let_validation_test.cc
+++ b/src/tint/resolver/var_let_validation_test.cc
@@ -29,7 +29,7 @@
WrapInFunction(Let(Source{{12, 34}}, "a", ty.i32(), nullptr));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: let declaration must have an initializer");
+ EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, GlobalLetNoInitializer) {
@@ -37,7 +37,7 @@
GlobalConst(Source{{12, 34}}, "a", ty.i32(), nullptr);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: let declaration must have an initializer");
+ EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) {
@@ -46,8 +46,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: function scope var declaration requires a type or "
- "initializer");
+ "12:34 error: function-scope 'var' declaration requires a type or initializer");
}
TEST_F(ResolverVarLetValidationTest, GlobalVarNoInitializerNoType) {
@@ -56,8 +55,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: module scope var declaration requires a type and "
- "initializer");
+ "12:34 error: module-scope 'var' declaration requires a type or initializer");
}
TEST_F(ResolverVarLetValidationTest, VarTypeNotStorable) {
diff --git a/src/tint/sem/type_mappings.h b/src/tint/sem/type_mappings.h
index 0e54eed..2b082a8 100644
--- a/src/tint/sem/type_mappings.h
+++ b/src/tint/sem/type_mappings.h
@@ -34,6 +34,7 @@
class Type;
class TypeDecl;
class Variable;
+class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@@ -50,6 +51,7 @@
class SwitchStatement;
class Type;
class Variable;
+class WhileStatement;
} // namespace tint::sem
namespace tint::sem {
@@ -74,6 +76,7 @@
Type* operator()(ast::Type*);
Type* operator()(ast::TypeDecl*);
Variable* operator()(ast::Variable*);
+ WhileStatement* operator()(ast::WhileStatement*);
//! @endcond
};
diff --git a/src/tint/sem/while_statement.cc b/src/tint/sem/while_statement.cc
new file mode 100644
index 0000000..495f83f
--- /dev/null
+++ b/src/tint/sem/while_statement.cc
@@ -0,0 +1,34 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/sem/while_statement.h"
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::sem::WhileStatement);
+
+namespace tint::sem {
+
+WhileStatement::WhileStatement(const ast::WhileStatement* declaration,
+ const CompoundStatement* parent,
+ const sem::Function* function)
+ : Base(declaration, parent, function) {}
+
+WhileStatement::~WhileStatement() = default;
+
+const ast::WhileStatement* WhileStatement::Declaration() const {
+ return static_cast<const ast::WhileStatement*>(Base::Declaration());
+}
+
+} // namespace tint::sem
diff --git a/src/tint/sem/while_statement.h b/src/tint/sem/while_statement.h
new file mode 100644
index 0000000..50f1831
--- /dev/null
+++ b/src/tint/sem/while_statement.h
@@ -0,0 +1,60 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0(the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_SEM_WHILE_STATEMENT_H_
+#define SRC_TINT_SEM_WHILE_STATEMENT_H_
+
+#include "src/tint/sem/statement.h"
+
+// Forward declarations
+namespace tint::ast {
+class WhileStatement;
+} // namespace tint::ast
+namespace tint::sem {
+class Expression;
+} // namespace tint::sem
+
+namespace tint::sem {
+
+/// Holds semantic information about a while statement
+class WhileStatement final : public Castable<WhileStatement, CompoundStatement> {
+ public:
+ /// Constructor
+ /// @param declaration the AST node for this while statement
+ /// @param parent the owning statement
+ /// @param function the owning function
+ WhileStatement(const ast::WhileStatement* declaration,
+ const CompoundStatement* parent,
+ const sem::Function* function);
+
+ /// Destructor
+ ~WhileStatement() override;
+
+ /// @returns the AST node
+ const ast::WhileStatement* Declaration() const;
+
+ /// @returns the whilecondition expression
+ const Expression* Condition() const { return condition_; }
+
+ /// Sets the while condition expression
+ /// @param condition the while condition expression
+ void SetCondition(const Expression* condition) { condition_ = condition; }
+
+ private:
+ const Expression* condition_ = nullptr;
+};
+
+} // namespace tint::sem
+
+#endif // SRC_TINT_SEM_WHILE_STATEMENT_H_
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
index 2f775ca..d15d790 100644
--- a/src/tint/transform/expand_compound_assignment.cc
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -44,6 +44,8 @@
return false;
}
+namespace {
+
/// Internal class used to collect statement expansions during the transform.
class State {
private:
@@ -163,6 +165,8 @@
}
};
+} // namespace
+
void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State state(ctx);
for (auto* node : ctx.src->ASTNodes().Objects()) {
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index 6f1cc4c..d527a4c 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -27,6 +27,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/manager.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@@ -383,6 +384,7 @@
ProcessStatement(s->expr);
},
[&](const ast::ForLoopStatement* s) { ProcessStatement(s->condition); },
+ [&](const ast::WhileStatement* s) { ProcessStatement(s->condition); },
[&](const ast::IfStatement* s) { //
ProcessStatement(s->condition);
},
@@ -578,6 +580,15 @@
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
+ [&](const ast::WhileStatement* s) -> const ast::Statement* {
+ if (!sem.Get(s->condition)->HasSideEffects()) {
+ return nullptr;
+ }
+ ast::StatementList stmts;
+ ctx.Replace(s->condition, Decompose(s->condition, &stmts));
+ InsertBefore(stmts, s);
+ return ctx.CloneWithoutTransform(s);
+ },
[&](const ast::IfStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
return nullptr;
diff --git a/src/tint/transform/promote_side_effects_to_decl_test.cc b/src/tint/transform/promote_side_effects_to_decl_test.cc
index 9d9115f..937b59e 100644
--- a/src/tint/transform/promote_side_effects_to_decl_test.cc
+++ b/src/tint/transform/promote_side_effects_to_decl_test.cc
@@ -999,6 +999,45 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InWhileCond) {
+ auto* src = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ while(a(0) + b > 0) {
+ var marker = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> i32 {
+ return i;
+}
+
+fn f() {
+ var b = 1;
+ loop {
+ let tint_symbol = a(0);
+ if (!(((tint_symbol + b) > 0))) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIf) {
auto* src = R"(
fn a(i : i32) -> i32 {
@@ -2299,6 +2338,48 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InWhileCond) {
+ auto* src = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ while(a(0) && b) {
+ var marker = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn a(i : i32) -> bool {
+ return true;
+}
+
+fn f() {
+ var b = true;
+ loop {
+ var tint_symbol = a(0);
+ if (tint_symbol) {
+ tint_symbol = b;
+ }
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ var marker = 0;
+ }
+ }
+}
+)";
+
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIf) {
auto* src = R"(
fn a(i : i32) -> bool {
diff --git a/src/tint/transform/remove_continue_in_switch.cc b/src/tint/transform/remove_continue_in_switch.cc
index 5c2413e..e5df23f 100644
--- a/src/tint/transform/remove_continue_in_switch.cc
+++ b/src/tint/transform/remove_continue_in_switch.cc
@@ -25,6 +25,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/switch_statement.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/utils/map.h"
@@ -49,7 +50,7 @@
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
- sem::ForLoopStatement>();
+ sem::ForLoopStatement, sem::WhileStatement>();
if (!sem_parent) {
return nullptr;
}
diff --git a/src/tint/transform/remove_continue_in_switch_test.cc b/src/tint/transform/remove_continue_in_switch_test.cc
index a1e7b6e..0b52457 100644
--- a/src/tint/transform/remove_continue_in_switch_test.cc
+++ b/src/tint/transform/remove_continue_in_switch_test.cc
@@ -559,5 +559,59 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(RemoveContinueInSwitchTest, While) {
+ auto* src = R"(
+fn f() {
+ var i = 0;
+ while (i < 4) {
+ let marker1 = 0;
+ switch(i) {
+ case 0: {
+ continue;
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var i = 0;
+ while((i < 4)) {
+ let marker1 = 0;
+ var tint_continue : bool = false;
+ switch(i) {
+ case 0: {
+ {
+ tint_continue = true;
+ break;
+ }
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ if (tint_continue) {
+ continue;
+ }
+ let marker2 = 0;
+ break;
+ }
+}
+)";
+
+ DataMap data;
+ auto got = Run<RemoveContinueInSwitch>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc
index e1ba74c..a7877e1 100644
--- a/src/tint/transform/unwind_discard_functions.cc
+++ b/src/tint/transform/unwind_discard_functions.cc
@@ -262,6 +262,15 @@
}
return nullptr;
},
+ [&](const ast::WhileStatement* s) -> const ast::Statement* {
+ if (MayDiscard(sem.Get(s->condition))) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unexpected WhileStatement condition that may discard. "
+ "Make sure transform::PromoteSideEffectsToDecl was run "
+ "first.";
+ }
+ return nullptr;
+ },
[&](const ast::IfStatement* s) -> const ast::Statement* {
auto* sem_expr = sem.Get(s->condition);
if (!MayDiscard(sem_expr)) {
diff --git a/src/tint/transform/unwind_discard_functions_test.cc b/src/tint/transform/unwind_discard_functions_test.cc
index 481df9d..7d1fa27 100644
--- a/src/tint/transform/unwind_discard_functions_test.cc
+++ b/src/tint/transform/unwind_discard_functions_test.cc
@@ -800,6 +800,67 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(UnwindDiscardFunctionsTest, While_Cond) {
+ auto* src = R"(
+fn f() -> i32 {
+ if (true) {
+ discard;
+ }
+ return 42;
+}
+
+@fragment
+fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
+ let marker1 = 0;
+ while (f() == 42) {
+ let marker2 = 0;
+ break;
+ }
+ return vec4<f32>();
+}
+)";
+ auto* expect = R"(
+var<private> tint_discard : bool = false;
+
+fn f() -> i32 {
+ if (true) {
+ tint_discard = true;
+ return i32();
+ }
+ return 42;
+}
+
+fn tint_discard_func() {
+ discard;
+}
+
+@fragment
+fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
+ let marker1 = 0;
+ loop {
+ let tint_symbol = f();
+ if (tint_discard) {
+ tint_discard_func();
+ return vec4<f32>();
+ }
+ if (!((tint_symbol == 42))) {
+ break;
+ }
+ {
+ let marker2 = 0;
+ break;
+ }
+ }
+ return vec4<f32>();
+}
+)";
+
+ DataMap data;
+ auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(UnwindDiscardFunctionsTest, Switch) {
auto* src = R"(
fn f() -> i32 {
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index 450a2e8..f5bf24b 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -22,6 +22,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/variable.h"
+#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/reverse.h"
namespace tint::transform {
@@ -46,7 +47,10 @@
};
/// For-loops that need to be decomposed to loops.
- std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
+ std::unordered_map<const sem::ForLoopStatement*, LoopInfo> for_loops;
+
+ /// Whiles that need to be decomposed to loops.
+ std::unordered_map<const sem::WhileStatement*, LoopInfo> while_loops;
/// 'else if' statements that need to be decomposed to 'else {if}'
std::unordered_map<const ast::IfStatement*, ElseIfInfo> else_ifs;
@@ -55,7 +59,7 @@
// registered declaration statements before the condition or continuing
// statement.
void ForLoopsToLoops() {
- if (loops.empty()) {
+ if (for_loops.empty()) {
return;
}
@@ -64,7 +68,7 @@
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
- if (auto it = loops.find(fl); it != loops.end()) {
+ if (auto it = for_loops.find(fl); it != for_loops.end()) {
auto& info = it->second;
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
@@ -108,6 +112,51 @@
});
}
+ // Converts any while-loops marked for conversion to loops, inserting
+ // registered declaration statements before the condition.
+ void WhilesToLoops() {
+ if (while_loops.empty()) {
+ return;
+ }
+
+ // At least one while needs to be transformed into a loop.
+ ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* {
+ auto& sem = ctx.src->Sem();
+
+ if (auto* w = sem.Get(stmt)) {
+ if (auto it = while_loops.find(w); it != while_loops.end()) {
+ auto& info = it->second;
+ auto* while_loop = w->Declaration();
+ // While needs to be decomposed to a loop.
+ // Build the loop body's statements.
+ // Start with any let declarations for the conditional
+ // expression.
+ auto body_stmts = info.cond_decls;
+ // Emit the condition as:
+ // if (!cond) { break; }
+ auto* cond = while_loop->condition;
+ // !condition
+ auto* not_cond =
+ b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
+ // { break; }
+ auto* break_body = b.Block(b.create<ast::BreakStatement>());
+ // if (!condition) { break; }
+ body_stmts.emplace_back(b.If(not_cond, break_body));
+
+ // Next emit the body
+ body_stmts.emplace_back(ctx.Clone(while_loop->body));
+
+ const ast::BlockStatement* continuing = nullptr;
+
+ auto* body = b.Block(body_stmts);
+ auto* loop = b.Loop(body, continuing);
+ return loop;
+ }
+ }
+ return nullptr;
+ });
+ }
+
void ElseIfsToElseWithNestedIfs() {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* else_if) -> const ast::Statement* {
@@ -192,7 +241,19 @@
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = loops[fl].cond_decls;
+ auto& decls = for_loops[fl].cond_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ if (auto* w = before_stmt->As<sem::WhileStatement>()) {
+ // Insertion point is a while condition.
+ // While needs to be decomposed to a loop.
+
+ // Index the map to convert this while, even if `stmt` is nullptr.
+ auto& decls = while_loops[w].cond_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@@ -227,7 +288,7 @@
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
- auto& decls = loops[fl].cont_decls;
+ auto& decls = for_loops[fl].cont_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@@ -257,6 +318,7 @@
/// @return true on success
bool Apply() {
ForLoopsToLoops();
+ WhilesToLoops();
ElseIfsToElseWithNestedIfs();
return true;
}
diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc
index 1e4cb8e..46d5551 100644
--- a/src/tint/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc
@@ -175,6 +175,47 @@
EXPECT_EQ(expect, str(cloned));
}
+TEST_F(HoistToDeclBeforeTest, WhileCond) {
+ // fn f() {
+ // var a : bool;
+ // while(a) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* expr = b.Expr("a");
+ auto* s = b.While(expr, b.Block());
+ b.Func("f", {}, b.ty.void_(), {var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* sem_expr = ctx.src->Sem().Get(expr);
+ hoistToDeclBefore.Add(sem_expr, expr, true);
+ hoistToDeclBefore.Apply();
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn f() {
+ var a : bool;
+ loop {
+ let tint_symbol = a;
+ if (!(tint_symbol)) {
+ break;
+ }
+ {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
TEST_F(HoistToDeclBeforeTest, ElseIf) {
// fn f() {
// var a : bool;
diff --git a/src/tint/transform/while_to_loop.cc b/src/tint/transform/while_to_loop.cc
new file mode 100644
index 0000000..00ebb01
--- /dev/null
+++ b/src/tint/transform/while_to_loop.cc
@@ -0,0 +1,67 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/while_to_loop.h"
+
+#include "src/tint/ast/break_statement.h"
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
+
+namespace tint::transform {
+
+WhileToLoop::WhileToLoop() = default;
+
+WhileToLoop::~WhileToLoop() = default;
+
+bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ast::WhileStatement>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
+ ast::StatementList stmts;
+ auto* cond = w->condition;
+
+ // !condition
+ auto* not_cond =
+ ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
+
+ // { break; }
+ auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
+
+ // if (!condition) { break; }
+ stmts.emplace_back(ctx.dst->If(not_cond, break_body));
+
+ for (auto* stmt : w->body->statements) {
+ stmts.emplace_back(ctx.Clone(stmt));
+ }
+
+ const ast::BlockStatement* continuing = nullptr;
+
+ auto* body = ctx.dst->Block(stmts);
+ auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
+
+ return loop;
+ });
+
+ ctx.Clone();
+}
+
+} // namespace tint::transform
diff --git a/src/tint/transform/while_to_loop.h b/src/tint/transform/while_to_loop.h
new file mode 100644
index 0000000..4915d68
--- /dev/null
+++ b/src/tint/transform/while_to_loop.h
@@ -0,0 +1,49 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
+#define SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
+
+#include "src/tint/transform/transform.h"
+
+namespace tint::transform {
+
+/// WhileToLoop is a Transform that converts a while statement into a loop
+/// statement. This is required by the SPIR-V writer.
+class WhileToLoop final : public Castable<WhileToLoop, Transform> {
+ public:
+ /// Constructor
+ WhileToLoop();
+
+ /// Destructor
+ ~WhileToLoop() override;
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::transform
+
+#endif // SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
diff --git a/src/tint/transform/while_to_loop_test.cc b/src/tint/transform/while_to_loop_test.cc
new file mode 100644
index 0000000..6e5699d
--- /dev/null
+++ b/src/tint/transform/while_to_loop_test.cc
@@ -0,0 +1,129 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/while_to_loop.h"
+
+#include "src/tint/transform/test_helper.h"
+
+namespace tint::transform {
+namespace {
+
+using WhileToLoopTest = TransformTest;
+
+TEST_F(WhileToLoopTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<WhileToLoop>(src));
+}
+
+TEST_F(WhileToLoopTest, ShouldRunHasWhile) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ break;
+ }
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<WhileToLoop>(src));
+}
+
+TEST_F(WhileToLoopTest, EmptyModule) {
+ auto* src = "";
+ auto* expect = src;
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test an empty for loop.
+TEST_F(WhileToLoopTest, Empty) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ break;
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a for loop with non-empty body.
+TEST_F(WhileToLoopTest, Body) {
+ auto* src = R"(
+fn f() {
+ while (true) {
+ discard;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!(true)) {
+ break;
+ }
+ discard;
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test a loop with a break condition
+TEST_F(WhileToLoopTest, BreakCondition) {
+ auto* src = R"(
+fn f() {
+ while (0 == 1) {
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ loop {
+ if (!((0 == 1))) {
+ break;
+ }
+ }
+}
+)";
+
+ auto got = Run<WhileToLoop>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+
+} // namespace tint::transform
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 82fb78e..441b0d4 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -1753,7 +1753,7 @@
}
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
- if (!emit_continuing_()) {
+ if (!emit_continuing_ || !emit_continuing_()) {
return false;
}
line() << "continue;";
@@ -2523,6 +2523,57 @@
return true;
}
+bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
+ TextBuffer cond_pre;
+ std::stringstream cond_buf;
+ {
+ auto* cond = stmt->condition;
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
+ if (!EmitExpression(cond_buf, cond)) {
+ return false;
+ }
+ }
+
+ auto emit_continuing = [&]() { return true; };
+ TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
+
+ // If the whilehas a multi-statement conditional, then we cannot emit this
+ // as a regular while in GLSL. Instead we need to generate a `while(true)` loop.
+ bool emit_as_loop = cond_pre.lines.size() > 0;
+ if (emit_as_loop) {
+ line() << "while (true) {";
+ increment_indent();
+ TINT_DEFER({
+ decrement_indent();
+ line() << "}";
+ });
+
+ current_buffer_->Append(cond_pre);
+ line() << "if (!(" << cond_buf.str() << ")) { break; }";
+
+ if (!EmitStatements(stmt->body->statements)) {
+ return false;
+ }
+ } else {
+ // While can be generated.
+ {
+ auto out = line();
+ out << "while";
+ {
+ ScopedParen sp(out);
+ out << cond_buf.str();
+ }
+ out << " {";
+ }
+ if (!EmitStatementsWithIndent(stmt->body->statements)) {
+ return false;
+ }
+ line() << "}";
+ }
+
+ return true;
+}
+
bool GeneratorImpl::EmitMemberAccessor(std::ostream& out,
const ast::MemberAccessorExpression* expr) {
if (!EmitExpression(out, expr->structure)) {
@@ -2591,6 +2642,9 @@
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
+ if (auto* l = stmt->As<ast::WhileStatement>()) {
+ return EmitWhile(l);
+ }
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 819c79b..ff1611c 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -357,6 +357,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
+ /// Handles a while statement
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted
+ bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles generating an identifier expression
/// @param out the output of the expression stream
/// @param expr the identifier expression
diff --git a/src/tint/writer/glsl/generator_impl_loop_test.cc b/src/tint/writer/glsl/generator_impl_loop_test.cc
index 5187daf..ec9219a 100644
--- a/src/tint/writer/glsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/glsl/generator_impl_loop_test.cc
@@ -381,5 +381,71 @@
)");
}
+TEST_F(GlslGeneratorImplTest_Loop, Emit_While) {
+ // while(true) {
+ // return;
+ // }
+
+ auto* f = While(Expr(true), Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ return;
+ }
+)");
+}
+
+TEST_F(GlslGeneratorImplTest_Loop, Emit_While_WithContinue) {
+ // while(true) {
+ // continue;
+ // }
+
+ auto* f = While(Expr(true), Block(Continue()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ continue;
+ }
+)");
+}
+
+TEST_F(GlslGeneratorImplTest_Loop, Emit_WhileWithMultiStmtCond) {
+ // while(true && false) {
+ // return;
+ // }
+
+ Func("a_statement", {}, ty.void_(), {});
+
+ auto* multi_stmt =
+ create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
+ auto* f = While(multi_stmt, Block(CallStmt(Call("a_statement"))));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while (true) {
+ bool tint_tmp = true;
+ if (tint_tmp) {
+ tint_tmp = false;
+ }
+ if (!((tint_tmp))) { break; }
+ a_statement();
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::glsl
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index a9e290c..b76b9e6 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -2634,7 +2634,7 @@
}
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
- if (!emit_continuing_()) {
+ if (!emit_continuing_ || !emit_continuing_()) {
return false;
}
line() << "continue;";
@@ -3481,6 +3481,56 @@
return true;
}
+bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
+ TextBuffer cond_pre;
+ std::stringstream cond_buf;
+ {
+ auto* cond = stmt->condition;
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
+ if (!EmitExpression(cond_buf, cond)) {
+ return false;
+ }
+ }
+
+ auto emit_continuing = [&]() { return true; };
+ TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
+
+ // If the while has a multi-statement conditional, then we cannot emit this
+ // as a regular while in HLSL. Instead we need to generate a `while(true)` loop.
+ bool emit_as_loop = cond_pre.lines.size() > 0;
+ if (emit_as_loop) {
+ line() << LoopAttribute() << "while (true) {";
+ increment_indent();
+ TINT_DEFER({
+ decrement_indent();
+ line() << "}";
+ });
+
+ current_buffer_->Append(cond_pre);
+ line() << "if (!(" << cond_buf.str() << ")) { break; }";
+ if (!EmitStatements(stmt->body->statements)) {
+ return false;
+ }
+ } else {
+ // While can be generated.
+ {
+ auto out = line();
+ out << LoopAttribute() << "while";
+ {
+ ScopedParen sp(out);
+ out << cond_buf.str();
+ }
+ out << " {";
+ }
+ if (!EmitStatementsWithIndent(stmt->body->statements)) {
+ return false;
+ }
+ line() << "}";
+ }
+
+ return true;
+}
+
bool GeneratorImpl::EmitMemberAccessor(std::ostream& out,
const ast::MemberAccessorExpression* expr) {
if (!EmitExpression(out, expr->structure)) {
@@ -3551,6 +3601,9 @@
[&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l);
},
+ [&](const ast::WhileStatement* l) { //
+ return EmitWhile(l);
+ },
[&](const ast::ReturnStatement* r) { //
return EmitReturn(r);
},
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index af7e4c9..0205142 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -353,6 +353,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
+ /// Handles a while statement
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted
+ bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles generating an identifier expression
/// @param out the output of the expression stream
/// @param expr the identifier expression
diff --git a/src/tint/writer/hlsl/generator_impl_loop_test.cc b/src/tint/writer/hlsl/generator_impl_loop_test.cc
index 0bf4090..29d1822 100644
--- a/src/tint/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_loop_test.cc
@@ -373,5 +373,69 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Loop, Emit_While) {
+ // while(true) {
+ // return;
+ // }
+
+ auto* f = While(Expr(true), Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( [loop] while(true) {
+ return;
+ }
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Loop, Emit_While_WithContinue) {
+ // while(true) {
+ // continue;
+ // }
+
+ auto* f = While(Expr(true), Block(Continue()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( [loop] while(true) {
+ continue;
+ }
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Loop, Emit_WhileWithMultiStmtCond) {
+ // while(true && false) {
+ // return;
+ // }
+
+ auto* multi_stmt =
+ create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
+ auto* f = While(multi_stmt, Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( [loop] while (true) {
+ bool tint_tmp = true;
+ if (tint_tmp) {
+ tint_tmp = false;
+ }
+ if (!((tint_tmp))) { break; }
+ return;
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::hlsl
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index aca2227..e5ba3a8 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1521,7 +1521,7 @@
}
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
- if (!emit_continuing_()) {
+ if (!emit_continuing_ || !emit_continuing_()) {
return false;
}
@@ -2124,6 +2124,56 @@
return true;
}
+bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
+ TextBuffer cond_pre;
+ std::stringstream cond_buf;
+
+ {
+ auto* cond = stmt->condition;
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
+ if (!EmitExpression(cond_buf, cond)) {
+ return false;
+ }
+ }
+
+ auto emit_continuing = [&]() { return true; };
+ TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
+
+ // If the while has a multi-statement conditional, then we cannot emit this
+ // as a regular while in MSL. Instead we need to generate a `while(true)` loop.
+ bool emit_as_loop = cond_pre.lines.size() > 0;
+ if (emit_as_loop) {
+ line() << "while (true) {";
+ increment_indent();
+ TINT_DEFER({
+ decrement_indent();
+ line() << "}";
+ });
+
+ current_buffer_->Append(cond_pre);
+ line() << "if (!(" << cond_buf.str() << ")) { break; }";
+ if (!EmitStatements(stmt->body->statements)) {
+ return false;
+ }
+ } else {
+ // While can be generated.
+ {
+ auto out = line();
+ out << "while";
+ {
+ ScopedParen sp(out);
+ out << cond_buf.str();
+ }
+ out << " {";
+ }
+ if (!EmitStatementsWithIndent(stmt->body->statements)) {
+ return false;
+ }
+ line() << "}";
+ }
+ return true;
+}
+
bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
// TODO(dsinclair): Verify this is correct when the discard semantics are
// defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
@@ -2280,6 +2330,9 @@
[&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l);
},
+ [&](const ast::WhileStatement* l) { //
+ return EmitWhile(l);
+ },
[&](const ast::ReturnStatement* r) { //
return EmitReturn(r);
},
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index be98a86..a05f3b1 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -270,6 +270,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
+ /// Handles a while statement
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emitted
+ bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles a member accessor expression
/// @param out the output of the expression stream
/// @param expr the member accessor expression
diff --git a/src/tint/writer/msl/generator_impl_loop_test.cc b/src/tint/writer/msl/generator_impl_loop_test.cc
index 248e711..85f1deb 100644
--- a/src/tint/writer/msl/generator_impl_loop_test.cc
+++ b/src/tint/writer/msl/generator_impl_loop_test.cc
@@ -344,5 +344,64 @@
)");
}
+TEST_F(MslGeneratorImplTest, Emit_While) {
+ // while(true) {
+ // return;
+ // }
+
+ auto* f = While(Expr(true), Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ return;
+ }
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_While_WithContinue) {
+ // while(true) {
+ // continue;
+ // }
+
+ auto* f = While(Expr(true), Block(Continue()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ continue;
+ }
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_WhileWithMultiCond) {
+ // while(true && false) {
+ // return;
+ // }
+
+ auto* multi_stmt =
+ create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
+ auto* f = While(multi_stmt, Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while((true && false)) {
+ return;
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::msl
diff --git a/src/tint/writer/spirv/generator_impl.cc b/src/tint/writer/spirv/generator_impl.cc
index b8ca89c..c5e9dad 100644
--- a/src/tint/writer/spirv/generator_impl.cc
+++ b/src/tint/writer/spirv/generator_impl.cc
@@ -32,6 +32,7 @@
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/var_for_dynamic_index.h"
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
+#include "src/tint/transform/while_to_loop.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
#include "src/tint/writer/generate_external_texture_bindings.h"
@@ -74,7 +75,7 @@
manager.Add<transform::SimplifyPointers>(); // Required for arrayLength()
manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::ForLoopToLoop>(); // Must come after
- // ZeroInitWorkgroupMemory
+ manager.Add<transform::WhileToLoop>(); // ZeroInitWorkgroupMemory
manager.Add<transform::CanonicalizeEntryPointIO>();
manager.Add<transform::AddEmptyEntryPoint>();
manager.Add<transform::AddSpirvBlockAttribute>();
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 5e0ce8c..35119ac 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -919,6 +919,7 @@
[&](const ast::IncrementDecrementStatement* l) { return EmitIncrementDecrement(l); },
[&](const ast::LoopStatement* l) { return EmitLoop(l); },
[&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
+ [&](const ast::WhileStatement* l) { return EmitWhile(l); },
[&](const ast::ReturnStatement* r) { return EmitReturn(r); },
[&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
[&](const ast::VariableDeclStatement* v) { return EmitVariable(line(), v->variable); },
@@ -1181,6 +1182,30 @@
return true;
}
+bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
+ {
+ auto out = line();
+ out << "while";
+ {
+ ScopedParen sp(out);
+
+ auto* cond = stmt->condition;
+ if (!EmitExpression(out, cond)) {
+ return false;
+ }
+ }
+ out << " {";
+ }
+
+ if (!EmitStatementsWithIndent(stmt->body->statements)) {
+ return false;
+ }
+
+ line() << "}";
+
+ return true;
+}
+
bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
auto out = line();
out << "return";
diff --git a/src/tint/writer/wgsl/generator_impl.h b/src/tint/writer/wgsl/generator_impl.h
index a17e2da..8ceeab2 100644
--- a/src/tint/writer/wgsl/generator_impl.h
+++ b/src/tint/writer/wgsl/generator_impl.h
@@ -152,6 +152,10 @@
/// @param stmt the statement to emit
/// @returns true if the statement was emtited
bool EmitForLoop(const ast::ForLoopStatement* stmt);
+ /// Handles a while statement
+ /// @param stmt the statement to emit
+ /// @returns true if the statement was emtited
+ bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles a member accessor expression
/// @param out the output of the expression stream
/// @param expr the member accessor expression
diff --git a/src/tint/writer/wgsl/generator_impl_loop_test.cc b/src/tint/writer/wgsl/generator_impl_loop_test.cc
index 2d6a8f4..3dbae60 100644
--- a/src/tint/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_loop_test.cc
@@ -198,5 +198,64 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_While) {
+ // while(true) {
+ // return;
+ // }
+
+ auto* f = While(Expr(true), Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ return;
+ }
+)");
+}
+
+TEST_F(WgslGeneratorImplTest, Emit_While_WithContinue) {
+ // while(true) {
+ // continue;
+ // }
+
+ auto* f = While(Expr(true), Block(Continue()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while(true) {
+ continue;
+ }
+)");
+}
+
+TEST_F(WgslGeneratorImplTest, Emit_WhileMultiCond) {
+ // while(true && false) {
+ // return;
+ // }
+
+ auto* multi_stmt =
+ create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
+ auto* f = While(multi_stmt, Block(Return()));
+ WrapInFunction(f);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( while((true && false)) {
+ return;
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::wgsl
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index c75f50e..3bb31c3 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -211,6 +211,7 @@
"../../src/tint/ast/variable_decl_statement_test.cc",
"../../src/tint/ast/variable_test.cc",
"../../src/tint/ast/vector_test.cc",
+ "../../src/tint/ast/while_statement_test.cc",
"../../src/tint/ast/workgroup_attribute_test.cc",
]
}
@@ -307,8 +308,8 @@
"../../src/tint/sem/sem_struct_test.cc",
"../../src/tint/sem/storage_texture_test.cc",
"../../src/tint/sem/texture_test.cc",
- "../../src/tint/sem/type_test.cc",
"../../src/tint/sem/type_manager_test.cc",
+ "../../src/tint/sem/type_test.cc",
"../../src/tint/sem/u32_test.cc",
"../../src/tint/sem/vector_test.cc",
]
@@ -359,6 +360,7 @@
"../../src/tint/transform/var_for_dynamic_index_test.cc",
"../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",
"../../src/tint/transform/vertex_pulling_test.cc",
+ "../../src/tint/transform/while_to_loop_test.cc",
"../../src/tint/transform/wrap_arrays_in_structs_test.cc",
"../../src/tint/transform/zero_init_workgroup_memory_test.cc",
]
@@ -552,6 +554,7 @@
"../../src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc",
"../../src/tint/reader/wgsl/parser_impl_variable_qualifier_test.cc",
"../../src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc",
+ "../../src/tint/reader/wgsl/parser_impl_while_stmt_test.cc",
"../../src/tint/reader/wgsl/parser_test.cc",
"../../src/tint/reader/wgsl/token_test.cc",
]