Import Tint changes from Dawn
Changes:
- 99084a411df0454ca5d144f0e574b4a0e54b429d tint/resolver: fix diagnostic source for test by Ben Clayton <bclayton@google.com>
- 72ac53e5fa8c7837750a3dfec350775da72e8aef Convert binding and group attributes to expressions. by dan sinclair <dsinclair@chromium.org>
- 308c55d9e0cbad1cf2a1f11ac8acb3face71ceee Convert `size` attribute to expressions. by dan sinclair <dsinclair@chromium.org>
- d9222f44c9b1f020fabb35e94e5c436320815537 tint/resolver: Validate discard is only used by fragment ... by Ben Clayton <bclayton@google.com>
- 1a567780d9b2c32d0ce1e8dfa5846a382139b643 tint/writer: Check extensions are supported by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 99084a411df0454ca5d144f0e574b4a0e54b429d
Change-Id: I74562d1c9fb6fdcf36a754bf76f1ef285619c2f2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/106080
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Copybara Prod <copybara-worker-blackhole@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index e034c34..e87e7cd 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -585,6 +585,8 @@
"writer/append_vector.h",
"writer/array_length_from_uniform_options.cc",
"writer/array_length_from_uniform_options.h",
+ "writer/check_supported_extensions.cc",
+ "writer/check_supported_extensions.h",
"writer/flatten_bindings.cc",
"writer/flatten_bindings.h",
"writer/float_to_string.cc",
@@ -1271,6 +1273,7 @@
tint_unittests_source_set("tint_unittests_writer_src") {
sources = [
"writer/append_vector_test.cc",
+ "writer/check_supported_extensions_test.cc",
"writer/flatten_bindings_test.cc",
"writer/float_to_string_test.cc",
"writer/generate_external_texture_bindings_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index fc62717..b3dd6ce 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -495,6 +495,8 @@
writer/append_vector.h
writer/array_length_from_uniform_options.cc
writer/array_length_from_uniform_options.h
+ writer/check_supported_extensions.cc
+ writer/check_supported_extensions.h
writer/flatten_bindings.cc
writer/flatten_bindings.h
writer/float_to_string.cc
@@ -889,6 +891,7 @@
utils/unique_vector_test.cc
utils/vector_test.cc
writer/append_vector_test.cc
+ writer/check_supported_extensions_test.cc
writer/flatten_bindings_test.cc
writer/float_to_string_test.cc
writer/generate_external_texture_bindings_test.cc
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 8430698..f300710 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -1642,6 +1642,16 @@
return enable;
}
+ /// Adds the extension to the list of enable directives at the top of the module.
+ /// @param source the enable source
+ /// @param ext the extension to enable
+ /// @return an `ast::Enable` enabling the given extension.
+ const ast::Enable* Enable(const Source& source, ast::Extension ext) {
+ auto* enable = create<ast::Enable>(source, ext);
+ AST().AddEnable(enable);
+ return enable;
+ }
+
/// @param name the variable name
/// @param options the extra options passed to the ast::Var constructor
/// Can be any of the following, in any order:
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 3dad333..7c06020 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3441,15 +3441,16 @@
if (t == "binding") {
const char* use = "binding attribute";
return expect_paren_block(use, [&]() -> Result {
- auto val = expect_positive_sint(use);
- if (val.errored) {
+ auto expr = expression();
+ if (expr.errored) {
return Failure::kErrored;
}
+ if (!expr.matched) {
+ return add_error(peek(), "expected binding expression");
+ }
match(Token::Type::kComma);
- return create<ast::BindingAttribute>(
- t.source(), create<ast::IntLiteralExpression>(
- val.value, ast::IntLiteralExpression::Suffix::kNone));
+ return create<ast::BindingAttribute>(t.source(), expr.value);
});
}
@@ -3478,15 +3479,16 @@
if (t == "group") {
const char* use = "group attribute";
return expect_paren_block(use, [&]() -> Result {
- auto val = expect_positive_sint(use);
- if (val.errored) {
+ auto expr = expression();
+ if (expr.errored) {
return Failure::kErrored;
}
+ if (!expr.matched) {
+ return add_error(peek(), "expected group expression");
+ }
match(Token::Type::kComma);
- return create<ast::GroupAttribute>(
- t.source(), create<ast::IntLiteralExpression>(
- val.value, ast::IntLiteralExpression::Suffix::kNone));
+ return create<ast::GroupAttribute>(t.source(), expr.value);
});
}
@@ -3551,13 +3553,16 @@
if (t == "size") {
const char* use = "size attribute";
return expect_paren_block(use, [&]() -> Result {
- auto val = expect_positive_sint(use);
- if (val.errored) {
+ auto expr = expression();
+ if (expr.errored) {
return Failure::kErrored;
}
+ if (!expr.matched) {
+ return add_error(peek(), "expected size expression");
+ }
match(Token::Type::kComma);
- return builder_.MemberSize(t.source(), AInt(val.value));
+ return builder_.MemberSize(t.source(), expr.value);
});
}
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index af9734c..0fc27dd 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -862,17 +862,9 @@
}
TEST_F(ParserImplErrorTest, GlobalDeclStructMemberSizeInvaldValue) {
- EXPECT("struct S { @size(x) i : i32, };",
- R"(test.wgsl:1:18 error: expected signed integer literal for size attribute
-struct S { @size(x) i : i32, };
- ^
-)");
-}
-
-TEST_F(ParserImplErrorTest, GlobalDeclStructMemberSizeNegativeValue) {
- EXPECT("struct S { @size(-2) i : i32, };",
- R"(test.wgsl:1:18 error: size attribute must be positive
-struct S { @size(-2) i : i32, };
+ EXPECT("struct S { @size(if) i : i32, };",
+ R"(test.wgsl:1:18 error: expected size expression
+struct S { @size(if) i : i32, };
^^
)");
}
@@ -1025,10 +1017,10 @@
}
TEST_F(ParserImplErrorTest, GlobalDeclVarAttrBindingInvalidValue) {
- EXPECT("@binding(x) var i : i32;",
- R"(test.wgsl:1:10 error: expected signed integer literal for binding attribute
-@binding(x) var i : i32;
- ^
+ EXPECT("@binding(if) var i : i32;",
+ R"(test.wgsl:1:10 error: expected binding expression
+@binding(if) var i : i32;
+ ^^
)");
}
@@ -1049,10 +1041,10 @@
}
TEST_F(ParserImplErrorTest, GlobalDeclVarAttrBindingGroupValue) {
- EXPECT("@group(x) var i : i32;",
- R"(test.wgsl:1:8 error: expected signed integer literal for group attribute
-@group(x) var i : i32;
- ^
+ EXPECT("@group(if) var i : i32;",
+ R"(test.wgsl:1:8 error: expected group expression
+@group(if) var i : i32;
+ ^^
)");
}
diff --git a/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc
index 388ad67..5f051db 100644
--- a/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc
@@ -139,7 +139,7 @@
EXPECT_NE(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:10: expected signed integer literal for binding attribute");
+ EXPECT_EQ(p->error(), "1:10: expected binding expression");
}
TEST_F(ParserImplTest, GlobalVariableDecl_InvalidConstExpr) {
diff --git a/src/tint/reader/wgsl/parser_impl_struct_body_decl_test.cc b/src/tint/reader/wgsl/parser_impl_struct_body_decl_test.cc
index ce56a67..e994230 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_body_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_body_decl_test.cc
@@ -71,12 +71,12 @@
TEST_F(ParserImplTest, StructBodyDecl_InvalidSize) {
auto p = parser(R"(
{
- @size(nan) a : i32,
+ @size(if) a : i32,
})");
auto m = p->expect_struct_body_decl();
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(m.errored);
- EXPECT_EQ(p->error(), "3:9: expected signed integer literal for size attribute");
+ EXPECT_EQ(p->error(), "3:9: expected size expression");
}
TEST_F(ParserImplTest, StructBodyDecl_MissingClosingBracket) {
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_decl_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_decl_test.cc
index 09d5274..0417f7c 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_decl_test.cc
@@ -39,12 +39,12 @@
}
TEST_F(ParserImplTest, AttributeDecl_InvalidAttribute) {
- auto p = parser("@size(nan)");
+ auto p = parser("@size(if)");
auto attrs = p->attribute_list();
EXPECT_TRUE(p->has_error()) << p->error();
EXPECT_TRUE(attrs.errored);
EXPECT_FALSE(attrs.matched);
- EXPECT_EQ(p->error(), "1:7: expected signed integer literal for size attribute");
+ EXPECT_EQ(p->error(), "1:7: expected size expression");
}
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
index f6957e4..80376d8 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
@@ -23,7 +23,7 @@
EXPECT_TRUE(attr.matched);
EXPECT_FALSE(attr.errored);
ASSERT_NE(attr.value, nullptr);
- ASSERT_FALSE(p->has_error());
+ ASSERT_FALSE(p->has_error()) << p->error();
auto* member_attr = attr.value->As<ast::Attribute>();
ASSERT_NE(member_attr, nullptr);
@@ -34,13 +34,39 @@
EXPECT_EQ(o->expr->As<ast::IntLiteralExpression>()->value, 4u);
}
+TEST_F(ParserImplTest, Attribute_Size_Expression) {
+ auto p = parser("size(4 + 5)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ ASSERT_FALSE(p->has_error()) << p->error();
+
+ auto* member_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(member_attr, nullptr);
+ ASSERT_TRUE(member_attr->Is<ast::StructMemberSizeAttribute>());
+
+ auto* o = member_attr->As<ast::StructMemberSizeAttribute>();
+ ASSERT_TRUE(o->expr->Is<ast::BinaryExpression>());
+ auto* expr = o->expr->As<ast::BinaryExpression>();
+
+ EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
+ auto* v = expr->lhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 4u);
+
+ v = expr->rhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 5u);
+}
+
TEST_F(ParserImplTest, Attribute_Size_TrailingComma) {
auto p = parser("size(4,)");
auto attr = p->attribute();
EXPECT_TRUE(attr.matched);
EXPECT_FALSE(attr.errored);
ASSERT_NE(attr.value, nullptr);
- ASSERT_FALSE(p->has_error());
+ ASSERT_FALSE(p->has_error()) << p->error();
auto* member_attr = attr.value->As<ast::Attribute>();
ASSERT_NE(member_attr, nullptr);
@@ -78,17 +104,17 @@
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:6: expected signed integer literal for size attribute");
+ EXPECT_EQ(p->error(), "1:6: expected size expression");
}
TEST_F(ParserImplTest, Attribute_Size_MissingInvalid) {
- auto p = parser("size(nan)");
+ auto p = parser("size(if)");
auto attr = p->attribute();
EXPECT_FALSE(attr.matched);
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:6: expected signed integer literal for size attribute");
+ EXPECT_EQ(p->error(), "1:6: expected size expression");
}
TEST_F(ParserImplTest, Attribute_Align) {
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
index 6267406..53174f6 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
@@ -115,14 +115,14 @@
}
TEST_F(ParserImplTest, StructMember_InvalidAttribute) {
- auto p = parser("@size(nan) a : i32,");
+ auto p = parser("@size(if) a : i32,");
auto m = p->expect_struct_member();
ASSERT_TRUE(m.errored);
ASSERT_EQ(m.value, nullptr);
ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:7: expected signed integer literal for size attribute");
+ EXPECT_EQ(p->error(), "1:7: expected size expression");
}
} // namespace
diff --git a/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc
index 4c68f9b..86e5487 100644
--- a/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_variable_attribute_test.cc
@@ -392,6 +392,31 @@
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
+TEST_F(ParserImplTest, Attribute_Binding_Expression) {
+ auto p = parser("binding(4 + 5)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ auto* var_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(var_attr, nullptr);
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(var_attr->Is<ast::BindingAttribute>());
+
+ auto* binding = var_attr->As<ast::BindingAttribute>();
+ ASSERT_TRUE(binding->expr->Is<ast::BinaryExpression>());
+ auto* expr = binding->expr->As<ast::BinaryExpression>();
+
+ EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
+ auto* v = expr->lhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 4u);
+
+ v = expr->rhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 5u);
+}
+
TEST_F(ParserImplTest, Attribute_Binding_TrailingComma) {
auto p = parser("binding(4,)");
auto attr = p->attribute();
@@ -437,17 +462,17 @@
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:9: expected signed integer literal for binding attribute");
+ EXPECT_EQ(p->error(), "1:9: expected binding expression");
}
TEST_F(ParserImplTest, Attribute_Binding_MissingInvalid) {
- auto p = parser("binding(nan)");
+ auto p = parser("binding(if)");
auto attr = p->attribute();
EXPECT_FALSE(attr.matched);
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:9: expected signed integer literal for binding attribute");
+ EXPECT_EQ(p->error(), "1:9: expected binding expression");
}
TEST_F(ParserImplTest, Attribute_group) {
@@ -468,6 +493,31 @@
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
+TEST_F(ParserImplTest, Attribute_group_expression) {
+ auto p = parser("group(4 + 5)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ auto* var_attr = attr.value->As<ast::Attribute>();
+ ASSERT_FALSE(p->has_error());
+ ASSERT_NE(var_attr, nullptr);
+ ASSERT_TRUE(var_attr->Is<ast::GroupAttribute>());
+
+ auto* group = var_attr->As<ast::GroupAttribute>();
+ ASSERT_TRUE(group->expr->Is<ast::BinaryExpression>());
+ auto* expr = group->expr->As<ast::BinaryExpression>();
+
+ EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
+ auto* v = expr->lhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 4u);
+
+ v = expr->rhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 5u);
+}
+
TEST_F(ParserImplTest, Attribute_group_TrailingComma) {
auto p = parser("group(4,)");
auto attr = p->attribute();
@@ -513,17 +563,17 @@
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:7: expected signed integer literal for group attribute");
+ EXPECT_EQ(p->error(), "1:7: expected group expression");
}
TEST_F(ParserImplTest, Attribute_Group_MissingInvalid) {
- auto p = parser("group(nan)");
+ auto p = parser("group(if)");
auto attr = p->attribute();
EXPECT_FALSE(attr.matched);
EXPECT_TRUE(attr.errored);
EXPECT_EQ(attr.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:7: expected signed integer literal for group attribute");
+ EXPECT_EQ(p->error(), "1:7: expected group expression");
}
} // namespace
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 30615bb..552b470 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -735,12 +735,85 @@
TEST_F(StructMemberAttributeTest, Align_Attribute_Override) {
Override("val", ty.f32(), Expr(1.23_f));
- Structure("mystruct", utils::Vector{Member(
- "a", ty.f32(), utils::Vector{MemberAlign(Source{{12, 34}}, "val")})});
+ Structure("mystruct",
+ utils::Vector{Member("a", ty.f32(),
+ utils::Vector{MemberAlign(Expr(Source{{12, 34}}, "val"))})});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(error: @align requires a const-expression, but expression is an override-expression)");
+ R"(12:34 error: @align requires a const-expression, but expression is an override-expression)");
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_Const) {
+ GlobalConst("val", ty.i32(), Expr(4_i));
+
+ Structure("mystruct", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize("val")})});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_ConstNegative) {
+ GlobalConst("val", ty.i32(), Expr(-2_i));
+
+ Structure("mystruct", utils::Vector{Member(
+ "a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, "val")})});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'size' attribute must be positive)");
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_ConstF32) {
+ GlobalConst("val", ty.f32(), Expr(1.23_f));
+
+ Structure("mystruct", utils::Vector{Member(
+ "a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, "val")})});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'size' must be an i32 or u32 value)");
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_ConstU32) {
+ GlobalConst("val", ty.u32(), Expr(4_u));
+
+ Structure("mystruct", utils::Vector{Member(
+ "a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, "val")})});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_ConstAInt) {
+ GlobalConst("val", Expr(4_a));
+
+ Structure("mystruct", utils::Vector{Member(
+ "a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, "val")})});
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_ConstAFloat) {
+ GlobalConst("val", Expr(2.0_a));
+
+ Structure("mystruct", utils::Vector{Member(
+ "a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, "val")})});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'size' must be an i32 or u32 value)");
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_Var) {
+ GlobalVar(Source{{1, 2}}, "val", ty.f32(), ast::AddressSpace::kPrivate, ast::Access::kUndefined,
+ Expr(1.23_f));
+
+ Structure(Source{{6, 4}}, "mystruct",
+ utils::Vector{Member(Source{{12, 5}}, "a", ty.f32(),
+ utils::Vector{MemberSize(Expr(Source{{12, 35}}, "val"))})});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:35 error: var 'val' cannot be referenced at module-scope
+1:2 note: var 'val' declared here)");
+}
+
+TEST_F(StructMemberAttributeTest, Size_Attribute_Override) {
+ Override("val", ty.f32(), Expr(1.23_f));
+
+ Structure("mystruct", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize("val")})});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: @size requires a const-expression, but expression is an override-expression)");
}
} // namespace StructAndStructMemberTests
@@ -1490,6 +1563,83 @@
R"(12:34 error: interpolate attribute must only be used with @location)");
}
+using GroupAndBindingTest = ResolverTest;
+
+TEST_F(GroupAndBindingTest, Const_I32) {
+ GlobalConst("b", Expr(4_i));
+ GlobalConst("g", Expr(2_i));
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding("b"),
+ Group("g"));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(GroupAndBindingTest, Const_U32) {
+ GlobalConst("b", Expr(4_u));
+ GlobalConst("g", Expr(2_u));
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding("b"),
+ Group("g"));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(GroupAndBindingTest, Const_AInt) {
+ GlobalConst("b", Expr(4_a));
+ GlobalConst("g", Expr(2_a));
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding("b"),
+ Group("g"));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(GroupAndBindingTest, Binding_Negative) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ Binding(Source{{12, 34}}, -2_i), Group(1_i));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'binding' value must be non-negative)");
+}
+
+TEST_F(GroupAndBindingTest, Binding_F32) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ Binding(Source{{12, 34}}, 2.0_f), Group(1_u));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'binding' must be an i32 or u32 value)");
+}
+
+TEST_F(GroupAndBindingTest, Binding_AFloat) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()),
+ Binding(Source{{12, 34}}, 2.0_a), Group(1_u));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'binding' must be an i32 or u32 value)");
+}
+
+TEST_F(GroupAndBindingTest, Group_Negative) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding(2_u),
+ Group(Source{{12, 34}}, -1_i));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'group' value must be non-negative)");
+}
+
+TEST_F(GroupAndBindingTest, Group_F32) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding(2_u),
+ Group(Source{{12, 34}}, 1.0_f));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'group' must be an i32 or u32 value)");
+}
+
+TEST_F(GroupAndBindingTest, Group_AFloat) {
+ GlobalVar("val", ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()), Binding(2_u),
+ Group(Source{{12, 34}}, 1.0_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: 'group' must be an i32 or u32 value)");
+}
+
} // namespace
} // namespace InterpolateTests
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 0b391c5..514e0cb 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -198,6 +198,7 @@
[&](const ast::Variable* var) {
Declare(var->symbol, var);
TraverseType(var->type);
+ TraverseAttributes(var->attributes);
if (var->constructor) {
TraverseExpression(var->constructor);
}
@@ -416,22 +417,38 @@
/// Traverses the attribute, performing symbol resolution and determining
/// global dependencies.
void TraverseAttribute(const ast::Attribute* attr) {
- if (auto* wg = attr->As<ast::WorkgroupAttribute>()) {
- TraverseExpression(wg->x);
- TraverseExpression(wg->y);
- TraverseExpression(wg->z);
- return;
- }
- if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) {
- TraverseExpression(align->expr);
+ bool handled = Switch(
+ attr,
+ [&](const ast::BindingAttribute* binding) {
+ TraverseExpression(binding->expr);
+ return true;
+ },
+ [&](const ast::GroupAttribute* group) {
+ TraverseExpression(group->expr);
+ return true;
+ },
+ [&](const ast::StructMemberAlignAttribute* align) {
+ TraverseExpression(align->expr);
+ return true;
+ },
+ [&](const ast::StructMemberSizeAttribute* size) {
+ TraverseExpression(size->expr);
+ return true;
+ },
+ [&](const ast::WorkgroupAttribute* wg) {
+ TraverseExpression(wg->x);
+ TraverseExpression(wg->y);
+ TraverseExpression(wg->z);
+ return true;
+ });
+ if (handled) {
return;
}
- if (attr->IsAnyOf<ast::BindingAttribute, ast::BuiltinAttribute, ast::GroupAttribute,
- ast::IdAttribute, ast::InternalAttribute, ast::InterpolateAttribute,
- ast::InvariantAttribute, ast::LocationAttribute, ast::StageAttribute,
- ast::StrideAttribute, ast::StructMemberOffsetAttribute,
- ast::StructMemberSizeAttribute>()) {
+ if (attr->IsAnyOf<ast::BuiltinAttribute, ast::IdAttribute, ast::InternalAttribute,
+ ast::InterpolateAttribute, ast::InvariantAttribute,
+ ast::LocationAttribute, ast::StageAttribute, ast::StrideAttribute,
+ ast::StructMemberOffsetAttribute>()) {
return;
}
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 614873e..150d65a 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1230,9 +1230,11 @@
Alias(Sym(), T);
Structure(Sym(), //
- utils::Vector{
- Member(Sym(), T, utils::Vector{MemberAlign(V)}) //
- });
+ utils::Vector{Member(Sym(), T,
+ utils::Vector{
+ //
+ MemberAlign(V), MemberSize(V) //
+ })});
GlobalVar(Sym(), T, V);
GlobalConst(Sym(), T, V);
Func(Sym(), //
@@ -1287,6 +1289,9 @@
GlobalVar(Sym(), ty.storage_texture(ast::TextureDimension::k2d, ast::TexelFormat::kR32Float,
ast::Access::kRead)); //
GlobalVar(Sym(), ty.sampler(ast::SamplerKind::kSampler));
+
+ GlobalVar(Sym(), ty.i32(), utils::Vector{Binding(V), Group(V)});
+
Func(Sym(), utils::Empty, ty.void_(), utils::Empty);
#undef V
#undef T
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 9fb0da4..be7dcd2 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -143,7 +143,7 @@
TEST_F(ResolverFunctionValidationTest, UnreachableCode_return_InBlocks) {
// fn func() -> {
// var a : i32;
- // utils::Vector {{{return;}}}
+ // {{{return;}}}
// a = 2i;
//}
@@ -184,7 +184,7 @@
TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) {
// fn func() -> {
// var a : i32;
- // utils::Vector {{{discard;}}}
+ // {{{discard;}}}
// a = 2i;
//}
@@ -202,6 +202,59 @@
EXPECT_FALSE(Sem().Get(assign_a)->IsReachable());
}
+TEST_F(ResolverFunctionValidationTest, DiscardCalledDirectlyFromVertexEntryPoint) {
+ // @vertex() fn func() -> @position(0) vec4<f32> { discard; }
+ Func(Source{{1, 2}}, "func", utils::Empty, ty.vec4<f32>(),
+ utils::Vector{
+ Discard(Source{{12, 34}}),
+ },
+ utils::Vector{Stage(ast::PipelineStage::kVertex)},
+ utils::Vector{Builtin(ast::BuiltinValue::kPosition)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: discard statement cannot be used in vertex pipeline stage");
+}
+
+TEST_F(ResolverFunctionValidationTest, DiscardCalledIndirectlyFromComputeEntryPoint) {
+ // fn f0 { discard; }
+ // fn f1 { f0(); }
+ // fn f2 { f1(); }
+ // @compute @workgroup_size(1) fn main { return f2(); }
+
+ Func(Source{{1, 2}}, "f0", utils::Empty, ty.void_(),
+ utils::Vector{
+ Discard(Source{{12, 34}}),
+ });
+
+ Func(Source{{3, 4}}, "f1", utils::Empty, ty.void_(),
+ utils::Vector{
+ CallStmt(Call("f0")),
+ });
+
+ Func(Source{{5, 6}}, "f2", utils::Empty, ty.void_(),
+ utils::Vector{
+ CallStmt(Call("f1")),
+ });
+
+ Func(Source{{7, 8}}, "main", utils::Empty, ty.void_(),
+ utils::Vector{
+ CallStmt(Call("f2")),
+ },
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1_i),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: discard statement cannot be used in compute pipeline stage
+1:2 note: called by function 'f0'
+3:4 note: called by function 'f1'
+5:6 note: called by function 'f2'
+7:8 note: called by entry point 'main')");
+}
+
TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) {
// fn func() -> int { var a:i32 = 2i; }
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index b8ba0fe..eebd5ea 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -619,34 +619,50 @@
if (var->HasBindingPoint()) {
uint32_t binding = 0;
{
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
auto* attr = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
- auto* materialize = Materialize(Expression(attr->expr));
- if (!materialize) {
+ auto* materialized = Materialize(Expression(attr->expr));
+ if (!materialized) {
return nullptr;
}
- auto* c = materialize->ConstantValue();
- if (!c) {
- // TODO(crbug.com/tint/1633): Add error message about invalid materialization
- // when binding can be an expression.
+ if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
+ AddError("'binding' must be an i32 or u32 value", attr->source);
return nullptr;
}
- binding = c->As<uint32_t>();
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->As<AInt>();
+ if (value < 0) {
+ AddError("'binding' value must be non-negative", attr->source);
+ return nullptr;
+ }
+ binding = u32(value);
}
uint32_t group = 0;
{
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
auto* attr = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
- auto* materialize = Materialize(Expression(attr->expr));
- if (!materialize) {
+ auto* materialized = Materialize(Expression(attr->expr));
+ if (!materialized) {
return nullptr;
}
- auto* c = materialize->ConstantValue();
- if (!c) {
- // TODO(crbug.com/tint/1633): Add error message about invalid materialization
- // when binding can be an expression.
+ if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
+ AddError("'group' must be an i32 or u32 value", attr->source);
return nullptr;
}
- group = c->As<uint32_t>();
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->As<AInt>();
+ if (value < 0) {
+ AddError("'group' value must be non-negative", attr->source);
+ return nullptr;
+ }
+ group = u32(value);
}
binding_point = {group, binding};
}
@@ -2885,15 +2901,26 @@
if (!materialized) {
return nullptr;
}
+ if (!materialized->Type()->IsAnyOf<sem::U32, sem::I32>()) {
+ AddError("'size' must be an i32 or u32 value", s->source);
+ return nullptr;
+ }
+
auto const_value = materialized->ConstantValue();
if (!const_value) {
AddError("'size' must be constant expression", s->expr->source);
return nullptr;
}
+ {
+ auto value = const_value->As<AInt>();
+ if (value <= 0) {
+ AddError("'size' attribute must be positive", s->source);
+ return nullptr;
+ }
+ }
auto value = const_value->As<uint64_t>();
-
if (value < size) {
- AddError("size must be at least as big as the type's size (" +
+ AddError("'size' must be at least as big as the type's size (" +
std::to_string(size) + ")",
s->source);
return nullptr;
@@ -3218,7 +3245,7 @@
builder_->create<sem::Statement>(stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Behaviors() = sem::Behavior::kDiscard;
- current_function_->SetHasDiscard();
+ current_function_->SetDiscardStatement(sem);
return validator_.DiscardStatement(sem, current_statement_);
});
diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc
index 43f0bf1..5002857 100644
--- a/src/tint/resolver/resolver_behavior_test.cc
+++ b/src/tint/resolver/resolver_behavior_test.cc
@@ -43,7 +43,9 @@
TEST_F(ResolverBehaviorTest, ExprBinaryOp_LHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(Call("DiscardOrNext"), 1_i)));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -53,7 +55,9 @@
TEST_F(ResolverBehaviorTest, ExprBinaryOp_RHS) {
auto* stmt = Decl(Var("lhs", ty.i32(), Add(1_i, Call("DiscardOrNext"))));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -63,7 +67,9 @@
TEST_F(ResolverBehaviorTest, ExprBitcastOp) {
auto* stmt = Decl(Var("lhs", ty.u32(), Bitcast<u32>(Call("DiscardOrNext"))));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -79,7 +85,9 @@
});
auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1_i)));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -89,8 +97,13 @@
TEST_F(ResolverBehaviorTest, ExprIndex_Idx) {
auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor("arr", Call("DiscardOrNext"))));
- WrapInFunction(Decl(Var("arr", ty.array<i32, 4>())), //
- stmt);
+
+ Func("F", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("arr", ty.array<i32, 4>())), //
+ stmt,
+ },
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -102,7 +115,9 @@
auto* stmt =
Decl(Var("lhs", ty.i32(),
create<ast::UnaryOpExpression>(ast::UnaryOp::kComplement, Call("DiscardOrNext"))));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -124,8 +139,13 @@
TEST_F(ResolverBehaviorTest, StmtAssign_LHSDiscardOrNext) {
auto* stmt = Assign(IndexAccessor("lhs", Call("DiscardOrNext")), 1_i);
- WrapInFunction(Decl(Var("lhs", ty.array<i32, 4>())), //
- stmt);
+
+ Func("F", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("lhs", ty.array<i32, 4>())), //
+ stmt,
+ },
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -135,8 +155,13 @@
TEST_F(ResolverBehaviorTest, StmtAssign_RHSDiscardOrNext) {
auto* stmt = Assign("lhs", Call("DiscardOrNext"));
- WrapInFunction(Decl(Var("lhs", ty.i32())), //
- stmt);
+
+ Func("F", utils::Empty, ty.void_(),
+ utils::Vector{
+ Decl(Var("lhs", ty.i32())), //
+ stmt,
+ },
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -156,7 +181,9 @@
TEST_F(ResolverBehaviorTest, StmtBlockSingleStmt) {
auto* stmt = Block(Discard());
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -178,7 +205,9 @@
TEST_F(ResolverBehaviorTest, StmtCallFuncDiscard) {
Func("f", utils::Empty, ty.void_(), utils::Vector{Discard()});
auto* stmt = CallStmt(Call("f"));
- WrapInFunction(stmt);
+
+ Func("g", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -189,7 +218,9 @@
TEST_F(ResolverBehaviorTest, StmtCallFuncMayDiscard) {
auto* stmt =
For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -220,7 +251,9 @@
TEST_F(ResolverBehaviorTest, StmtDiscard) {
auto* stmt = Discard();
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -256,7 +289,9 @@
TEST_F(ResolverBehaviorTest, StmtForLoopDiscard) {
auto* stmt = For(nullptr, nullptr, nullptr, Block(Discard()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -277,7 +312,9 @@
TEST_F(ResolverBehaviorTest, StmtForLoopBreak_InitCallFuncMayDiscard) {
auto* stmt =
For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block(Break()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -287,7 +324,9 @@
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_InitCallFuncMayDiscard) {
auto* stmt = For(Decl(Var("v", ty.i32(), Call("DiscardOrNext"))), nullptr, nullptr, Block());
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -307,7 +346,9 @@
TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
auto* stmt = For(nullptr, Equal(Call("DiscardOrNext"), 1_i), nullptr, Block());
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -327,7 +368,9 @@
TEST_F(ResolverBehaviorTest, StmtWhileDiscard) {
auto* stmt = While(Expr(true), Block(Discard()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -357,7 +400,9 @@
TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondCallFuncMayDiscard) {
auto* stmt = While(Equal(Call("DiscardOrNext"), 1_i), Block());
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -377,7 +422,9 @@
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard) {
auto* stmt = If(true, Block(Discard()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -387,7 +434,9 @@
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseDiscard) {
auto* stmt = If(true, Block(), Else(Block(Discard())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -397,7 +446,9 @@
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenDiscard_ElseDiscard) {
auto* stmt = If(true, Block(Discard()), Else(Block(Discard())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -407,7 +458,9 @@
TEST_F(ResolverBehaviorTest, StmtIfCallFuncMayDiscard_ThenEmptyBlock) {
auto* stmt = If(Equal(Call("DiscardOrNext"), 1_i), Block());
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -418,7 +471,9 @@
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock_ElseCallFuncMayDiscard) {
auto* stmt = If(true, Block(), //
Else(If(Equal(Call("DiscardOrNext"), 1_i), Block())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -438,7 +493,9 @@
TEST_F(ResolverBehaviorTest, StmtLetDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Let("lhs", ty.i32(), Call("DiscardOrNext")));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -474,7 +531,9 @@
TEST_F(ResolverBehaviorTest, StmtLoopDiscard) {
auto* stmt = Loop(Block(Discard()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -522,6 +581,7 @@
TEST_F(ResolverBehaviorTest, StmtReturn_DiscardOrNext) {
auto* stmt = Return(Call("DiscardOrNext"));
+
Func("F", utils::Empty, ty.i32(), utils::Vector{stmt});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -552,7 +612,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_DefaultDiscard) {
auto* stmt = Switch(1_i, DefaultCase(Block(Discard())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -582,7 +644,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Discard())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -602,7 +666,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -612,7 +678,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Discard())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -622,7 +690,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Return())));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -635,7 +705,9 @@
Case(Expr(0_i), Block(Discard())), //
Case(Expr(1_i), Block(Return())), //
DefaultCase(Block()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -646,7 +718,9 @@
TEST_F(ResolverBehaviorTest, StmtSwitch_CondCallFuncMayDiscard_DefaultEmpty) {
auto* stmt = Switch(Call("DiscardOrNext"), DefaultCase(Block()));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -666,7 +740,9 @@
TEST_F(ResolverBehaviorTest, StmtVarDecl_RHSDiscardOrNext) {
auto* stmt = Decl(Var("lhs", ty.i32(), Call("DiscardOrNext")));
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index f4fa3cc..edce7eb 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -1266,11 +1266,11 @@
TEST_F(ResolverValidationTest, ZeroStructMemberSizeAttribute) {
Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, 0_a)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, 1_a)}),
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: size must be at least as big as the type's size (4)");
+ EXPECT_EQ(r()->error(), "12:34 error: 'size' must be at least as big as the type's size (4)");
}
TEST_F(ResolverValidationTest, OffsetAndSizeAttribute) {
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 3066315..e14cdd6 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1942,6 +1942,18 @@
}
bool Validator::PipelineStages(const std::vector<sem::Function*>& entry_points) const {
+ auto backtrace = [&](const sem::Function* func, const sem::Function* entry_point) {
+ if (func != entry_point) {
+ TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
+ AddNote("called by function '" + symbols_.NameFor(f->Declaration()->symbol) + "'",
+ f->Declaration()->source);
+ });
+ AddNote("called by entry point '" +
+ symbols_.NameFor(entry_point->Declaration()->symbol) + "'",
+ entry_point->Declaration()->source);
+ }
+ };
+
auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage();
@@ -1959,17 +1971,7 @@
}
}
AddNote("variable is declared here", var->Declaration()->source);
- if (func != entry_point) {
- TraverseCallChain(
- diagnostics_, entry_point, func, [&](const sem::Function* f) {
- AddNote("called by function '" +
- symbols_.NameFor(f->Declaration()->symbol) + "'",
- f->Declaration()->source);
- });
- AddNote("called by entry point '" +
- symbols_.NameFor(entry_point->Declaration()->symbol) + "'",
- entry_point->Declaration()->source);
- }
+ backtrace(func, entry_point);
return false;
}
}
@@ -1977,17 +1979,6 @@
return true;
};
- for (auto* entry_point : entry_points) {
- if (!check_workgroup_storage(entry_point, entry_point)) {
- return false;
- }
- for (auto* func : entry_point->TransitivelyCalledFunctions()) {
- if (!check_workgroup_storage(func, entry_point)) {
- return false;
- }
- }
- }
-
auto check_builtin_calls = [&](const sem::Function* func, const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage();
for (auto* builtin : func->DirectlyCalledBuiltins()) {
@@ -1997,16 +1988,34 @@
err << "built-in cannot be used by " << stage << " pipeline stage";
AddError(err.str(),
call ? call->Declaration()->source : func->Declaration()->source);
- if (func != entry_point) {
- TraverseCallChain(diagnostics_, entry_point, func, [&](const sem::Function* f) {
- AddNote("called by function '" +
- symbols_.NameFor(f->Declaration()->symbol) + "'",
- f->Declaration()->source);
- });
- AddNote("called by entry point '" +
- symbols_.NameFor(entry_point->Declaration()->symbol) + "'",
- entry_point->Declaration()->source);
- }
+ backtrace(func, entry_point);
+ return false;
+ }
+ }
+ return true;
+ };
+
+ auto check_no_discards = [&](const sem::Function* func, const sem::Function* entry_point) {
+ if (auto* discard = func->DiscardStatement()) {
+ auto stage = entry_point->Declaration()->PipelineStage();
+ std::stringstream err;
+ err << "discard statement cannot be used in " << stage << " pipeline stage";
+ AddError(err.str(), discard->Declaration()->source);
+ backtrace(func, entry_point);
+ return false;
+ }
+ return true;
+ };
+
+ auto check_func = [&](const sem::Function* func, const sem::Function* entry_point) {
+ if (!check_workgroup_storage(func, entry_point)) {
+ return false;
+ }
+ if (!check_builtin_calls(func, entry_point)) {
+ return false;
+ }
+ if (entry_point->Declaration()->PipelineStage() != ast::PipelineStage::kFragment) {
+ if (!check_no_discards(func, entry_point)) {
return false;
}
}
@@ -2014,15 +2023,16 @@
};
for (auto* entry_point : entry_points) {
- if (!check_builtin_calls(entry_point, entry_point)) {
+ if (!check_func(entry_point, entry_point)) {
return false;
}
for (auto* func : entry_point->TransitivelyCalledFunctions()) {
- if (!check_builtin_calls(func, entry_point)) {
+ if (!check_func(func, entry_point)) {
return false;
}
}
}
+
return true;
}
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index 9ef99ab..973aa43 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -237,12 +237,17 @@
/// @returns true if `sym` is an ancestor entry point of this function
bool HasAncestorEntryPoint(Symbol sym) const;
- /// Sets that this function has a discard statement
- void SetHasDiscard() { has_discard_ = true; }
+ /// Records the first discard statement in the function
+ /// @param stmt the `discard` statement.
+ void SetDiscardStatement(const Statement* stmt) {
+ if (!discard_stmt_) {
+ discard_stmt_ = stmt;
+ }
+ }
- /// Returns true if this function has a discard statement
- /// @returns true if this function has a discard statement
- bool HasDiscard() const { return has_discard_; }
+ /// @returns the first discard statement for the function, or nullptr if the function does not
+ /// use `discard`.
+ const Statement* DiscardStatement() const { return discard_stmt_; }
/// @return the behaviors of this function
const sem::Behaviors& Behaviors() const { return behaviors_; }
@@ -271,7 +276,7 @@
std::vector<const Call*> direct_calls_;
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
- bool has_discard_ = false;
+ const Statement* discard_stmt_ = nullptr;
sem::Behaviors behaviors_{sem::Behavior::kNext};
std::optional<uint32_t> return_location_;
diff --git a/src/tint/writer/check_supported_extensions.cc b/src/tint/writer/check_supported_extensions.cc
new file mode 100644
index 0000000..3bb16bf
--- /dev/null
+++ b/src/tint/writer/check_supported_extensions.cc
@@ -0,0 +1,48 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/check_supported_extensions.h"
+
+#include <string>
+
+#include "src/tint/ast/module.h"
+#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/utils/hashset.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::writer {
+
+bool CheckSupportedExtensions(std::string_view writer_name,
+ const ast::Module& module,
+ diag::List& diags,
+ utils::VectorRef<ast::Extension> supported) {
+ utils::Hashset<ast::Extension, 32> set;
+ for (auto ext : supported) {
+ set.Add(ext);
+ }
+
+ for (auto* enable : module.Enables()) {
+ auto ext = enable->extension;
+ if (!set.Contains(ext)) {
+ diags.add_error(diag::System::Writer,
+ std::string(writer_name) + " backend does not support extension '" +
+ utils::ToString(ext) + "'",
+ enable->source);
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace tint::writer
diff --git a/src/tint/writer/check_supported_extensions.h b/src/tint/writer/check_supported_extensions.h
new file mode 100644
index 0000000..c1884d0
--- /dev/null
+++ b/src/tint/writer/check_supported_extensions.h
@@ -0,0 +1,43 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_WRITER_CHECK_SUPPORTED_EXTENSIONS_H_
+#define SRC_TINT_WRITER_CHECK_SUPPORTED_EXTENSIONS_H_
+
+#include "src/tint/ast/extension.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::ast {
+class Module;
+} // namespace tint::ast
+namespace tint::diag {
+class List;
+} // namespace tint::diag
+
+namespace tint::writer {
+
+/// Checks that all the extensions enabled in @p module are found in @p supported, raising an error
+/// diagnostic if an enabled extension is not supported.
+/// @param writer_name the name of the writer making this call
+/// @param module the AST module
+/// @param diags the diagnostics to append an error to, if needed.
+/// @returns true if all extensions in use are supported, otherwise returns false.
+bool CheckSupportedExtensions(std::string_view writer_name,
+ const ast::Module& module,
+ diag::List& diags,
+ utils::VectorRef<ast::Extension> supported);
+
+} // namespace tint::writer
+
+#endif // SRC_TINT_WRITER_CHECK_SUPPORTED_EXTENSIONS_H_
diff --git a/src/tint/writer/check_supported_extensions_test.cc b/src/tint/writer/check_supported_extensions_test.cc
new file mode 100644
index 0000000..97c3b32
--- /dev/null
+++ b/src/tint/writer/check_supported_extensions_test.cc
@@ -0,0 +1,47 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/check_supported_extensions.h"
+
+#include "gtest/gtest.h"
+
+#include "src/tint/program_builder.h"
+
+namespace tint::writer {
+namespace {
+
+class CheckSupportedExtensionsTest : public ::testing::Test, public ProgramBuilder {};
+
+TEST_F(CheckSupportedExtensionsTest, Supported) {
+ Enable(ast::Extension::kF16);
+
+ ASSERT_TRUE(CheckSupportedExtensions("writer", AST(), Diagnostics(),
+ utils::Vector{
+ ast::Extension::kF16,
+ ast::Extension::kChromiumExperimentalDp4A,
+ }));
+}
+
+TEST_F(CheckSupportedExtensionsTest, Unsupported) {
+ Enable(Source{{12, 34}}, ast::Extension::kF16);
+
+ ASSERT_FALSE(CheckSupportedExtensions("writer", AST(), Diagnostics(),
+ utils::Vector{
+ ast::Extension::kChromiumExperimentalDp4A,
+ }));
+ EXPECT_EQ(Diagnostics().str(), "12:34 error: writer backend does not support extension 'f16'");
+}
+
+} // namespace
+} // namespace tint::writer
diff --git a/src/tint/writer/glsl/generator_impl_block_test.cc b/src/tint/writer/glsl/generator_impl_block_test.cc
index 014c1c7..bf377fb 100644
--- a/src/tint/writer/glsl/generator_impl_block_test.cc
+++ b/src/tint/writer/glsl/generator_impl_block_test.cc
@@ -20,7 +20,7 @@
using GlslGeneratorImplTest_Block = TestHelper;
TEST_F(GlslGeneratorImplTest_Block, Emit_Block) {
- auto* b = Block(create<ast::DiscardStatement>());
+ auto* b = Block(Return());
WrapInFunction(b);
GeneratorImpl& gen = Build();
@@ -29,7 +29,7 @@
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
- discard;
+ return;
}
)");
}
diff --git a/src/tint/writer/glsl/generator_impl_discard_test.cc b/src/tint/writer/glsl/generator_impl_discard_test.cc
index 87c85cb..8be1961 100644
--- a/src/tint/writer/glsl/generator_impl_discard_test.cc
+++ b/src/tint/writer/glsl/generator_impl_discard_test.cc
@@ -21,7 +21,9 @@
TEST_F(GlslGeneratorImplTest_Discard, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>();
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/glsl/generator_impl_loop_test.cc b/src/tint/writer/glsl/generator_impl_loop_test.cc
index 64074ef..e638eb3 100644
--- a/src/tint/writer/glsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/glsl/generator_impl_loop_test.cc
@@ -27,7 +27,8 @@
auto* continuing = Block();
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -47,7 +48,8 @@
auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -81,7 +83,9 @@
continuing = Block(Assign(lhs, rhs));
auto* outer = Loop(body, continuing);
- WrapInFunction(outer);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index a51b954..439f0ee 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -72,6 +72,7 @@
#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/utils/string.h"
#include "src/tint/writer/append_vector.h"
+#include "src/tint/writer/check_supported_extensions.h"
#include "src/tint/writer/float_to_string.h"
#include "src/tint/writer/generate_external_texture_bindings.h"
@@ -254,6 +255,16 @@
GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() {
+ if (!CheckSupportedExtensions("HLSL", program_->AST(), diagnostics_,
+ utils::Vector{
+ ast::Extension::kChromiumDisableUniformityAnalysis,
+ ast::Extension::kChromiumExperimentalDp4A,
+ ast::Extension::kChromiumExperimentalPushConstant,
+ ast::Extension::kF16,
+ })) {
+ return false;
+ }
+
const TypeInfo* last_kind = nullptr;
size_t last_padding_line = 0;
@@ -2756,7 +2767,7 @@
out << ") {";
}
- if (sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>()) {
+ if (sem->DiscardStatement() && !sem->ReturnType()->Is<sem::Void>()) {
// BUG(crbug.com/tint/1081): work around non-void functions with discard
// failing compilation sometimes
if (!EmitFunctionBodyWithDiscard(func)) {
@@ -2780,7 +2791,7 @@
// there is always an (unused) return statement.
auto* sem = builder_.Sem().Get(func);
- TINT_ASSERT(Writer, sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>());
+ TINT_ASSERT(Writer, sem->DiscardStatement() && !sem->ReturnType()->Is<sem::Void>());
ScopedIndent si(this);
line() << "if (true) {";
diff --git a/src/tint/writer/hlsl/generator_impl_block_test.cc b/src/tint/writer/hlsl/generator_impl_block_test.cc
index 9a6cada..c3687da 100644
--- a/src/tint/writer/hlsl/generator_impl_block_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_block_test.cc
@@ -20,7 +20,7 @@
using HlslGeneratorImplTest_Block = TestHelper;
TEST_F(HlslGeneratorImplTest_Block, Emit_Block) {
- auto* b = Block(create<ast::DiscardStatement>());
+ auto* b = Block(Return());
WrapInFunction(b);
GeneratorImpl& gen = Build();
@@ -29,7 +29,7 @@
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
- discard;
+ return;
}
)");
}
diff --git a/src/tint/writer/hlsl/generator_impl_discard_test.cc b/src/tint/writer/hlsl/generator_impl_discard_test.cc
index 4bc4bf9..52bb958 100644
--- a/src/tint/writer/hlsl/generator_impl_discard_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_discard_test.cc
@@ -21,7 +21,9 @@
TEST_F(HlslGeneratorImplTest_Discard, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>();
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl_loop_test.cc b/src/tint/writer/hlsl/generator_impl_loop_test.cc
index 92f966c..238fdd2 100644
--- a/src/tint/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_loop_test.cc
@@ -27,7 +27,8 @@
auto* continuing = Block();
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -47,7 +48,8 @@
auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -81,7 +83,9 @@
continuing = Block(Assign(lhs, rhs));
auto* outer = Loop(body, continuing);
- WrapInFunction(outer);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl_test.cc b/src/tint/writer/hlsl/generator_impl_test.cc
index 35b9c82..8fa1156 100644
--- a/src/tint/writer/hlsl/generator_impl_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_test.cc
@@ -28,6 +28,15 @@
EXPECT_EQ(result.error, "input program is not valid");
}
+TEST_F(HlslGeneratorImplTest, UnsupportedExtension) {
+ Enable(Source{{12, 34}}, ast::Extension::kUndefined);
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_FALSE(gen.Generate());
+ EXPECT_EQ(gen.error(), R"(12:34 error: HLSL backend does not support extension 'undefined')");
+}
+
TEST_F(HlslGeneratorImplTest, Generate) {
Func("my_func", {}, ty.void_(), {});
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 27a5191..081444f 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -76,6 +76,7 @@
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
+#include "src/tint/writer/check_supported_extensions.h"
#include "src/tint/writer/float_to_string.h"
#include "src/tint/writer/generate_external_texture_bindings.h"
@@ -278,6 +279,15 @@
GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() {
+ if (!CheckSupportedExtensions("MSL", program_->AST(), diagnostics_,
+ utils::Vector{
+ ast::Extension::kChromiumDisableUniformityAnalysis,
+ ast::Extension::kChromiumExperimentalPushConstant,
+ ast::Extension::kF16,
+ })) {
+ return false;
+ }
+
line() << "#include <metal_stdlib>";
line();
line() << "using namespace metal;";
diff --git a/src/tint/writer/msl/generator_impl_block_test.cc b/src/tint/writer/msl/generator_impl_block_test.cc
index 9e73eac..5d4862c 100644
--- a/src/tint/writer/msl/generator_impl_block_test.cc
+++ b/src/tint/writer/msl/generator_impl_block_test.cc
@@ -20,7 +20,7 @@
using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, Emit_Block) {
- auto* b = Block(create<ast::DiscardStatement>());
+ auto* b = Block(Return());
WrapInFunction(b);
GeneratorImpl& gen = Build();
@@ -29,13 +29,13 @@
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
- discard_fragment();
+ return;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_Block_WithoutNewline) {
- auto* b = Block(create<ast::DiscardStatement>());
+ auto* b = Block(Return());
WrapInFunction(b);
GeneratorImpl& gen = Build();
@@ -44,7 +44,7 @@
ASSERT_TRUE(gen.EmitBlock(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
- discard_fragment();
+ return;
}
)");
}
diff --git a/src/tint/writer/msl/generator_impl_discard_test.cc b/src/tint/writer/msl/generator_impl_discard_test.cc
index 5f5c17f..8477129 100644
--- a/src/tint/writer/msl/generator_impl_discard_test.cc
+++ b/src/tint/writer/msl/generator_impl_discard_test.cc
@@ -21,7 +21,9 @@
TEST_F(MslGeneratorImplTest, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>();
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/msl/generator_impl_loop_test.cc b/src/tint/writer/msl/generator_impl_loop_test.cc
index 41017ff..274ee94 100644
--- a/src/tint/writer/msl/generator_impl_loop_test.cc
+++ b/src/tint/writer/msl/generator_impl_loop_test.cc
@@ -26,7 +26,9 @@
auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block();
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -45,7 +47,9 @@
auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -76,7 +80,9 @@
continuing = Block(Assign("lhs", "rhs"));
auto* outer = Loop(body, continuing);
- WrapInFunction(outer);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{outer},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/msl/generator_impl_test.cc b/src/tint/writer/msl/generator_impl_test.cc
index bff94b0..2299c19 100644
--- a/src/tint/writer/msl/generator_impl_test.cc
+++ b/src/tint/writer/msl/generator_impl_test.cc
@@ -31,6 +31,15 @@
EXPECT_EQ(result.error, "input program is not valid");
}
+TEST_F(MslGeneratorImplTest, UnsupportedExtension) {
+ Enable(Source{{12, 34}}, ast::Extension::kUndefined);
+
+ GeneratorImpl& gen = Build();
+
+ ASSERT_FALSE(gen.Generate());
+ EXPECT_EQ(gen.error(), R"(12:34 error: MSL backend does not support extension 'undefined')");
+}
+
TEST_F(MslGeneratorImplTest, Generate) {
Func("my_func", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 83587f1..96edf6c 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -47,6 +47,7 @@
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
#include "src/tint/writer/append_vector.h"
+#include "src/tint/writer/check_supported_extensions.h"
namespace tint::writer::spirv {
namespace {
@@ -259,6 +260,17 @@
Builder::~Builder() = default;
bool Builder::Build() {
+ if (!CheckSupportedExtensions("SPIR-V", builder_.AST(), builder_.Diagnostics(),
+ utils::Vector{
+ ast::Extension::kChromiumDisableUniformityAnalysis,
+ ast::Extension::kChromiumExperimentalDp4A,
+ ast::Extension::kChromiumExperimentalPushConstant,
+ ast::Extension::kF16,
+ })) {
+ error_ = builder_.Diagnostics().str();
+ return false;
+ }
+
push_capability(SpvCapabilityShader);
push_memory_model(spv::Op::OpMemoryModel,
diff --git a/src/tint/writer/spirv/builder_discard_test.cc b/src/tint/writer/spirv/builder_discard_test.cc
index 8c49abe..747b85c 100644
--- a/src/tint/writer/spirv/builder_discard_test.cc
+++ b/src/tint/writer/spirv/builder_discard_test.cc
@@ -21,13 +21,15 @@
using BuilderTest = TestHelper;
TEST_F(BuilderTest, Discard) {
- auto* expr = create<ast::DiscardStatement>();
- WrapInFunction(expr);
+ auto* stmt = Discard();
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build();
b.push_function(Function{});
- EXPECT_TRUE(b.GenerateStatement(expr)) << b.error();
+ EXPECT_TRUE(b.GenerateStatement(stmt)) << b.error();
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpKill
)");
}
diff --git a/src/tint/writer/spirv/builder_test.cc b/src/tint/writer/spirv/builder_test.cc
index 24d5b72..b3dce7e 100644
--- a/src/tint/writer/spirv/builder_test.cc
+++ b/src/tint/writer/spirv/builder_test.cc
@@ -29,6 +29,15 @@
EXPECT_EQ(result.error, "input program is not valid");
}
+TEST_F(BuilderTest, UnsupportedExtension) {
+ Enable(Source{{12, 34}}, ast::Extension::kUndefined);
+
+ auto program = std::make_unique<Program>(std::move(*this));
+ auto result = Generate(program.get(), Options{});
+ EXPECT_EQ(result.error,
+ R"(12:34 error: SPIR-V backend does not support extension 'undefined')");
+}
+
TEST_F(BuilderTest, TracksIdBounds) {
spirv::Builder& b = Build();
diff --git a/src/tint/writer/wgsl/generator_impl_block_test.cc b/src/tint/writer/wgsl/generator_impl_block_test.cc
index ae01200..0609f97 100644
--- a/src/tint/writer/wgsl/generator_impl_block_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_block_test.cc
@@ -20,7 +20,7 @@
using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Block) {
- auto* b = Block(create<ast::DiscardStatement>());
+ auto* b = Block(Return());
WrapInFunction(b);
GeneratorImpl& gen = Build();
@@ -29,7 +29,7 @@
ASSERT_TRUE(gen.EmitStatement(b)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
- discard;
+ return;
}
)");
}
diff --git a/src/tint/writer/wgsl/generator_impl_discard_test.cc b/src/tint/writer/wgsl/generator_impl_discard_test.cc
index db176e9..07cf399 100644
--- a/src/tint/writer/wgsl/generator_impl_discard_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_discard_test.cc
@@ -21,7 +21,9 @@
TEST_F(WgslGeneratorImplTest, Emit_Discard) {
auto* stmt = create<ast::DiscardStatement>();
- WrapInFunction(stmt);
+
+ Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/wgsl/generator_impl_loop_test.cc b/src/tint/writer/wgsl/generator_impl_loop_test.cc
index f4b6898..bf0eef6 100644
--- a/src/tint/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_loop_test.cc
@@ -26,7 +26,8 @@
auto* continuing = Block();
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -46,7 +47,8 @@
auto* continuing = Block(CallStmt(Call("a_statement")));
auto* l = Loop(body, continuing);
- WrapInFunction(l);
+ Func("F", utils::Empty, ty.void_(), utils::Vector{l},
+ utils::Vector{Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();