resolver: Allocate IDs for named pipeline constants
Keep track of any constant IDs specified in the shader, and then
allocate IDs for the remaining constants when creating the semantic
info.
Bug: tint:755
Change-Id: I6a76b1193cac459b62582cde7469b092dde51d5d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50841
Commit-Queue: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 2bd1b09..6e5f3ab 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -550,6 +550,7 @@
resolver/intrinsic_test.cc
resolver/is_host_shareable_test.cc
resolver/is_storeable_test.cc
+ resolver/pipeline_overridable_constant_test.cc
resolver/resolver_test_helper.cc
resolver/resolver_test_helper.h
resolver/resolver_test.cc
diff --git a/src/program_builder.h b/src/program_builder.h
index 7c8d5ba..833cb33 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -38,6 +38,7 @@
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
#include "src/ast/multisampled_texture.h"
+#include "src/ast/override_decoration.h"
#include "src/ast/pointer.h"
#include "src/ast/return_statement.h"
#include "src/ast/sampled_texture.h"
@@ -1786,6 +1787,32 @@
return create<ast::LocationDecoration>(source_, location);
}
+ /// Creates an ast::OverrideDecoration with a specific constant ID
+ /// @param source the source information
+ /// @param id the id value
+ /// @returns the override decoration pointer
+ ast::OverrideDecoration* Override(const Source& source, uint32_t id) {
+ return create<ast::OverrideDecoration>(source, id);
+ }
+
+ /// Creates an ast::OverrideDecoration with a specific constant ID
+ /// @param id the optional id value
+ /// @returns the override decoration pointer
+ ast::OverrideDecoration* Override(uint32_t id) {
+ return Override(source_, id);
+ }
+
+ /// Creates an ast::OverrideDecoration without a constant ID
+ /// @param source the source information
+ /// @returns the override decoration pointer
+ ast::OverrideDecoration* Override(const Source& source) {
+ return create<ast::OverrideDecoration>(source);
+ }
+
+ /// Creates an ast::OverrideDecoration without a constant ID
+ /// @returns the override decoration pointer
+ ast::OverrideDecoration* Override() { return Override(source_); }
+
/// Creates an ast::StageDecoration
/// @param source the source information
/// @param stage the pipeline stage
diff --git a/src/resolver/pipeline_overridable_constant_test.cc b/src/resolver/pipeline_overridable_constant_test.cc
new file mode 100644
index 0000000..48a3456
--- /dev/null
+++ b/src/resolver/pipeline_overridable_constant_test.cc
@@ -0,0 +1,113 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/resolver/resolver_test_helper.h"
+
+using ::testing::UnorderedElementsAre;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverPipelineOverridableConstantTest = ResolverTest;
+
+TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
+ auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_a = Sem().Get(a);
+ ASSERT_NE(sem_a, nullptr);
+ EXPECT_EQ(sem_a->Declaration(), a);
+ EXPECT_FALSE(sem_a->IsPipelineConstant());
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
+ auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override(7u)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_a = Sem().Get(a);
+ ASSERT_NE(sem_a, nullptr);
+ EXPECT_EQ(sem_a->Declaration(), a);
+ EXPECT_TRUE(sem_a->IsPipelineConstant());
+ EXPECT_EQ(sem_a->ConstantId(), 7u);
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
+ auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_a = Sem().Get(a);
+ ASSERT_NE(sem_a, nullptr);
+ EXPECT_EQ(sem_a->Declaration(), a);
+ EXPECT_TRUE(sem_a->IsPipelineConstant());
+ EXPECT_EQ(sem_a->ConstantId(), 0u);
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
+ std::vector<ast::Variable*> variables;
+ variables.push_back(
+ GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()}));
+ variables.push_back(
+ GlobalConst("b", ty.f32(), Construct(ty.f32()), {Override()}));
+ variables.push_back(
+ GlobalConst("c", ty.f32(), Construct(ty.f32()), {Override(2u)}));
+ variables.push_back(
+ GlobalConst("d", ty.f32(), Construct(ty.f32()), {Override(4u)}));
+ variables.push_back(
+ GlobalConst("e", ty.f32(), Construct(ty.f32()), {Override()}));
+ variables.push_back(
+ GlobalConst("f", ty.f32(), Construct(ty.f32()), {Override(1u)}));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ std::vector<uint16_t> constant_ids;
+ for (auto* var : variables) {
+ auto* sem = Sem().Get(var);
+ ASSERT_NE(sem, nullptr);
+ constant_ids.push_back(sem->ConstantId());
+ }
+ EXPECT_THAT(constant_ids, UnorderedElementsAre(0u, 3u, 2u, 4u, 5u, 1u));
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
+ GlobalConst("a", ty.f32(), Construct(ty.f32()),
+ {Override(Source{{12, 34}}, 7u)});
+ GlobalConst("b", ty.f32(), Construct(ty.f32()),
+ {Override(Source{{56, 78}}, 7u)});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(), R"(56:78 error: pipeline constant IDs must be unique
+12:34 note: a pipeline constant with an ID of 7 was previously declared here:)");
+}
+
+TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
+ GlobalConst("a", ty.f32(), Construct(ty.f32()),
+ {Override(Source{{12, 34}}, 65536u)});
+
+ EXPECT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(r()->error(),
+ "12:34 error: pipeline constant IDs must be between 0 and 65535");
+}
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index b38c2d1..328cb15 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -519,6 +519,13 @@
for (auto* deco : var->decorations()) {
Mark(deco);
+
+ if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
+ // Track the constant IDs that are specified in the shader.
+ if (override_deco->HasValue()) {
+ constant_ids_.emplace(override_deco->value(), info);
+ }
+ }
}
if (auto bp = var->binding_point()) {
@@ -543,7 +550,30 @@
bool Resolver::ValidateGlobalVariable(const VariableInfo* info) {
for (auto* deco : info->declaration->decorations()) {
if (info->declaration->is_const()) {
- if (!deco->Is<ast::OverrideDecoration>()) {
+ if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
+ if (override_deco->HasValue()) {
+ uint32_t id = override_deco->value();
+ auto itr = constant_ids_.find(id);
+ if (itr != constant_ids_.end() && itr->second != info) {
+ diagnostics_.add_error("pipeline constant IDs must be unique",
+ deco->source());
+ diagnostics_.add_note("a pipeline constant with an ID of " +
+ std::to_string(id) +
+ " was previously declared "
+ "here:",
+ ast::GetDecoration<ast::OverrideDecoration>(
+ itr->second->declaration->decorations())
+ ->source());
+ return false;
+ }
+ if (id > 65535) {
+ diagnostics_.add_error(
+ "pipeline constant IDs must be between 0 and 65535",
+ deco->source());
+ return false;
+ }
+ }
+ } else {
diagnostics_.add_error("decoration is not valid for constants",
deco->source());
return false;
@@ -2244,12 +2274,42 @@
}
}
+ // The next pipeline constant ID to try to allocate.
+ uint16_t next_constant_id = 0;
+
// Create semantic nodes for all ast::Variables
for (auto it : variable_to_info_) {
auto* var = it.first;
auto* info = it.second;
- auto* sem_var =
- builder_->create<sem::Variable>(var, info->type, info->storage_class);
+
+ sem::Variable* sem_var = nullptr;
+
+ if (auto* override_deco =
+ ast::GetDecoration<ast::OverrideDecoration>(var->decorations())) {
+ // Create a pipeline overridable constant.
+ uint16_t constant_id;
+ if (override_deco->HasValue()) {
+ constant_id = override_deco->value();
+ } else {
+ // No ID was specified, so allocate the next available ID.
+ constant_id = next_constant_id;
+ while (constant_ids_.count(constant_id)) {
+ if (constant_id == UINT16_MAX) {
+ TINT_ICE(builder_->Diagnostics())
+ << "no more pipeline constant IDs available";
+ return;
+ }
+ constant_id++;
+ }
+ next_constant_id = constant_id + 1;
+ }
+
+ sem_var = builder_->create<sem::Variable>(var, info->type, constant_id);
+ } else {
+ sem_var =
+ builder_->create<sem::Variable>(var, info->type, info->storage_class);
+ }
+
std::vector<const sem::VariableUser*> users;
for (auto* user : info->users) {
// Create semantic node for the identifier expression if necessary
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 95a7365..1f367f1 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -347,6 +347,7 @@
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<Symbol, sem::Type*> named_types_;
std::unordered_set<const ast::Node*> marked_;
+ std::unordered_map<uint32_t, const VariableInfo*> constant_ids_;
FunctionInfo* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;
diff --git a/src/sem/variable.cc b/src/sem/variable.cc
index 363785e..83210cc 100644
--- a/src/sem/variable.cc
+++ b/src/sem/variable.cc
@@ -26,7 +26,19 @@
Variable::Variable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class)
- : declaration_(declaration), type_(type), storage_class_(storage_class) {}
+ : declaration_(declaration),
+ type_(type),
+ storage_class_(storage_class),
+ is_pipeline_constant_(false) {}
+
+Variable::Variable(const ast::Variable* declaration,
+ const sem::Type* type,
+ uint16_t constant_id)
+ : declaration_(declaration),
+ type_(type),
+ storage_class_(ast::StorageClass::kNone),
+ is_pipeline_constant_(true),
+ constant_id_(constant_id) {}
Variable::~Variable() = default;
diff --git a/src/sem/variable.h b/src/sem/variable.h
index 2090ee2..ba6118f 100644
--- a/src/sem/variable.h
+++ b/src/sem/variable.h
@@ -37,7 +37,7 @@
/// Variable holds the semantic information for variables.
class Variable : public Castable<Variable, Node> {
public:
- /// Constructor
+ /// Constructor for variables and non-overridable constants
/// @param declaration the AST declaration node
/// @param type the variable type
/// @param storage_class the variable storage class
@@ -45,6 +45,14 @@
const sem::Type* type,
ast::StorageClass storage_class);
+ /// Constructor for overridable pipeline constants
+ /// @param declaration the AST declaration node
+ /// @param type the variable type
+ /// @param constant_id the pipeline constant ID
+ Variable(const ast::Variable* declaration,
+ const sem::Type* type,
+ uint16_t constant_id);
+
/// Destructor
~Variable() override;
@@ -63,11 +71,19 @@
/// @param user the user to add
void AddUser(const VariableUser* user) { users_.emplace_back(user); }
+ /// @returns true if this variable is an overridable pipeline constant
+ bool IsPipelineConstant() const { return is_pipeline_constant_; }
+
+ /// @returns the pipeline constant ID associated with the variable
+ uint32_t ConstantId() const { return constant_id_; }
+
private:
const ast::Variable* const declaration_;
const sem::Type* const type_;
ast::StorageClass const storage_class_;
std::vector<const VariableUser*> users_;
+ const bool is_pipeline_constant_;
+ const uint16_t constant_id_ = 0;
};
/// VariableUser holds the semantic information for an identifier expression
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 4f6bf11..f5c3440 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -257,6 +257,7 @@
"../src/resolver/intrinsic_test.cc",
"../src/resolver/is_host_shareable_test.cc",
"../src/resolver/is_storeable_test.cc",
+ "../src/resolver/pipeline_overridable_constant_test.cc",
"../src/resolver/resolver_test.cc",
"../src/resolver/resolver_test_helper.cc",
"../src/resolver/resolver_test_helper.h",