tint/resolver: Resolve static_assert
No readers produce this, yet.
Bug: tint:1625
Change-Id: I94ce3e5afd7bd81b0a5059451136aa0eed7e9283
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97961
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 037793a..a51bd22 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1126,6 +1126,7 @@
"resolver/resolver_test_helper.cc",
"resolver/resolver_test_helper.h",
"resolver/side_effects_test.cc",
+ "resolver/static_assert_test.cc",
"resolver/source_variable_test.cc",
"resolver/storage_class_layout_validation_test.cc",
"resolver/storage_class_validation_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 5f93189..9a6a42d 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -809,6 +809,7 @@
resolver/resolver_test_helper.h
resolver/resolver_test.cc
resolver/side_effects_test.cc
+ resolver/static_assert_test.cc
resolver/source_variable_test.cc
resolver/storage_class_layout_validation_test.cc
resolver/storage_class_validation_test.cc
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 36bff3f..2ddcaaa 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -204,6 +204,7 @@
[&](const ast::Enable*) {
// Enable directives do not effect the dependency graph.
},
+ [&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) { UnhandledNode(diagnostics_, global->node); });
}
@@ -315,6 +316,7 @@
TraverseExpression(w->condition);
TraverseStatement(w->body);
},
+ [&](const ast::StaticAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) {
if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) {
@@ -515,6 +517,8 @@
[&](const ast::TypeDecl* td) { return td->name; },
[&](const ast::Function* func) { return func->symbol; },
[&](const ast::Variable* var) { return var->symbol; },
+ [&](const ast::Enable*) { return Symbol(); },
+ [&](const ast::StaticAssert*) { return Symbol(); },
[&](Default) {
UnhandledNode(diagnostics_, node);
return Symbol{};
@@ -533,11 +537,12 @@
/// declaration
std::string KindOf(const ast::Node* node) {
return Switch(
- node, //
- [&](const ast::Struct*) { return "struct"; }, //
- [&](const ast::Alias*) { return "alias"; }, //
- [&](const ast::Function*) { return "function"; }, //
- [&](const ast::Variable* v) { return v->Kind(); }, //
+ node, //
+ [&](const ast::Struct*) { return "struct"; }, //
+ [&](const ast::Alias*) { return "alias"; }, //
+ [&](const ast::Function*) { return "function"; }, //
+ [&](const ast::Variable* v) { return v->Kind(); }, //
+ [&](const ast::StaticAssert*) { return "static_assert"; }, //
[&](Default) {
UnhandledNode(diagnostics_, node);
return "<error>";
@@ -549,9 +554,8 @@
void GatherGlobals(const ast::Module& module) {
for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node);
- // Enable directives do not form a symbol. Skip them.
- if (!node->Is<ast::Enable>()) {
- globals_.emplace(SymbolOf(node), global);
+ if (auto symbol = SymbolOf(node); symbol.IsValid()) {
+ globals_.emplace(symbol, global);
}
declaration_order_.emplace_back(global);
}
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 64c3e03..4295bab 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -140,6 +140,7 @@
[&](const ast::TypeDecl* td) { return TypeDecl(td); },
[&](const ast::Function* func) { return Function(func); },
[&](const ast::Variable* var) { return GlobalVariable(var); },
+ [&](const ast::StaticAssert* sa) { return StaticAssert(sa); },
[&](Default) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled global declaration: " << decl->TypeInfo().name;
@@ -737,6 +738,33 @@
return sem;
}
+sem::Statement* Resolver::StaticAssert(const ast::StaticAssert* assertion) {
+ auto* expr = Expression(assertion->condition);
+ if (!expr) {
+ return nullptr;
+ }
+ auto* cond = expr->ConstantValue();
+ if (!cond) {
+ AddError("static assertion condition must be a constant expression",
+ assertion->condition->source);
+ return nullptr;
+ }
+ if (auto* ty = cond->Type(); !ty->Is<sem::Bool>()) {
+ AddError(
+ "static assertion condition must be a bool, got '" + builder_->FriendlyName(ty) + "'",
+ assertion->condition->source);
+ return nullptr;
+ }
+ if (!cond->As<bool>()) {
+ AddError("static assertion failed", assertion->source);
+ return nullptr;
+ }
+ auto* sem =
+ builder_->create<sem::Statement>(assertion, current_compound_statement_, current_function_);
+ builder_->Sem().Add(assertion, sem);
+ return sem;
+}
+
sem::Function* Resolver::Function(const ast::Function* decl) {
uint32_t parameter_index = 0;
std::unordered_map<Symbol, Source> parameter_names;
@@ -1042,6 +1070,7 @@
[&](const ast::IncrementDecrementStatement* i) { return IncrementDecrementStatement(i); },
[&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
[&](const ast::VariableDeclStatement* v) { return VariableDeclStatement(v); },
+ [&](const ast::StaticAssert* sa) { return StaticAssert(sa); },
// Error cases
[&](const ast::CaseStatement*) {
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 10b5806..2b6a240 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -250,6 +250,7 @@
sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*);
+ sem::Statement* StaticAssert(const ast::StaticAssert*);
sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(utils::VectorRef<const ast::Statement*>);
diff --git a/src/tint/resolver/static_assert_test.cc b/src/tint/resolver/static_assert_test.cc
new file mode 100644
index 0000000..3cb67c9
--- /dev/null
+++ b/src/tint/resolver/static_assert_test.cc
@@ -0,0 +1,110 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/tint/resolver/resolver_test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::resolver {
+namespace {
+
+using ResolverStaticAssertTest = ResolverTest;
+
+TEST_F(ResolverStaticAssertTest, Global_True_Pass) {
+ GlobalStaticAssert(true);
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStaticAssertTest, Global_False_Fail) {
+ GlobalStaticAssert(Source{{12, 34}}, false);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+TEST_F(ResolverStaticAssertTest, Global_Const_Pass) {
+ GlobalConst("C", ty.bool_(), Expr(true));
+ GlobalStaticAssert("C");
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStaticAssertTest, Global_Const_Fail) {
+ GlobalConst("C", ty.bool_(), Expr(false));
+ GlobalStaticAssert(Source{{12, 34}}, "C");
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
+TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Pass) {
+ GlobalStaticAssert(LessThan(2_i, 3_i));
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
+TEST_F(ResolverStaticAssertTest, DISABLED_Global_LessThan_Fail) {
+ GlobalStaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+TEST_F(ResolverStaticAssertTest, Local_True_Pass) {
+ WrapInFunction(StaticAssert(true));
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStaticAssertTest, Local_False_Fail) {
+ WrapInFunction(StaticAssert(Source{{12, 34}}, false));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+TEST_F(ResolverStaticAssertTest, Local_Const_Pass) {
+ GlobalConst("C", ty.bool_(), Expr(true));
+ WrapInFunction(StaticAssert("C"));
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverStaticAssertTest, Local_Const_Fail) {
+ GlobalConst("C", ty.bool_(), Expr(false));
+ WrapInFunction(StaticAssert(Source{{12, 34}}, "C"));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+TEST_F(ResolverStaticAssertTest, Local_NonConst) {
+ GlobalVar("V", ty.bool_(), Expr(true), ast::StorageClass::kPrivate);
+ WrapInFunction(StaticAssert(Expr(Source{{12, 34}}, "V")));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: static assertion condition must be a constant expression");
+}
+
+// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
+TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Pass) {
+ WrapInFunction(StaticAssert(LessThan(2_i, 3_i)));
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+// TODO(crbug.com/tint/1581): Enable once the '<' operator is implemented for constant evaluation.
+TEST_F(ResolverStaticAssertTest, DISABLED_Local_LessThan_Fail) {
+ WrapInFunction(StaticAssert(Source{{12, 34}}, LessThan(4_i, 3_i)));
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: static assertion failed");
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 555ea1a..9ccb7ef 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -847,6 +847,7 @@
return cfx;
}
},
+
[&](const ast::ReturnStatement* r) {
Node* cf_ret;
if (r->value) {
@@ -870,6 +871,7 @@
return cf_ret;
},
+
[&](const ast::SwitchStatement* s) {
auto* sem_switch = sem_.Get(s);
auto [cfx, v_cond] = ProcessExpression(cf, s->condition);
@@ -938,6 +940,7 @@
return cf_end ? cf_end : cf;
},
+
[&](const ast::VariableDeclStatement* decl) {
Node* node;
if (decl->variable->constructor) {
@@ -956,6 +959,11 @@
return cf;
},
+
+ [&](const ast::StaticAssert*) {
+ return cf; // No impact on uniformity
+ },
+
[&](Default) {
TINT_ICE(Resolver, diagnostics_)
<< "unknown statement type: " << std::string(stmt->TypeInfo().name);