resolver: Fix constant propagation for POC
Pipeline overidable constants are not compile-time constant.
If a module-scope const has an [[override]] decoration, do not assign
the constant value to it, as this will propagate, and the constant value
may become inlined in places that should be overridable.
Also: Rename sem::GlobalVariable::IsPipelineConstant() to
IsOverridable() to make it clearer that this is not a compile-time known
value. Add SetIsOverridable() so we can correctly set the
IsOverridable() flag even when there isn't an ID.
Change-Id: I5ede9dd180d5ff1696b3868ea4313fc28f93af4b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69140
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index 171dbc8..5fead48 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -193,7 +193,7 @@
auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>();
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
OverridableConstant overridable_constant;
overridable_constant.name = name;
overridable_constant.numeric_id = global->ConstantId();
@@ -245,7 +245,7 @@
std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
- if (!global || !global->IsPipelineConstant()) {
+ if (!global || !global->IsOverridable()) {
continue;
}
@@ -300,7 +300,7 @@
std::map<std::string, uint32_t> result;
for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
auto name = program_->Symbols().NameFor(var->symbol);
result[name] = global->ConstantId();
}
diff --git a/src/resolver/pipeline_overridable_constant_test.cc b/src/resolver/pipeline_overridable_constant_test.cc
index 1ff83dc..074d561 100644
--- a/src/resolver/pipeline_overridable_constant_test.cc
+++ b/src/resolver/pipeline_overridable_constant_test.cc
@@ -30,24 +30,26 @@
auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var);
- EXPECT_TRUE(sem->IsPipelineConstant());
+ EXPECT_TRUE(sem->IsOverridable());
EXPECT_EQ(sem->ConstantId(), id);
+ EXPECT_FALSE(sem->ConstantValue());
}
};
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
- auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()));
+ auto* a = GlobalConst("a", ty.f32(), Expr(1.f));
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
ASSERT_NE(sem_a, nullptr);
EXPECT_EQ(sem_a->Declaration(), a);
- EXPECT_FALSE(sem_a->IsPipelineConstant());
+ EXPECT_FALSE(sem_a->IsOverridable());
+ EXPECT_TRUE(sem_a->ConstantValue());
}
TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
- auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override(7u)});
+ auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override(7u)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -55,7 +57,7 @@
}
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
- auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
+ auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -64,12 +66,12 @@
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
std::vector<ast::Variable*> variables;
- auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
- auto* b = GlobalConst("b", ty.f32(), Construct(ty.f32()), {Override()});
- auto* c = GlobalConst("c", ty.f32(), Construct(ty.f32()), {Override(2u)});
- auto* d = GlobalConst("d", ty.f32(), Construct(ty.f32()), {Override(4u)});
- auto* e = GlobalConst("e", ty.f32(), Construct(ty.f32()), {Override()});
- auto* f = GlobalConst("f", ty.f32(), Construct(ty.f32()), {Override(1u)});
+ auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
+ auto* b = GlobalConst("b", ty.f32(), Expr(1.f), {Override()});
+ auto* c = GlobalConst("c", ty.f32(), Expr(1.f), {Override(2u)});
+ auto* d = GlobalConst("d", ty.f32(), Expr(1.f), {Override(4u)});
+ auto* e = GlobalConst("e", ty.f32(), Expr(1.f), {Override()});
+ auto* f = GlobalConst("f", ty.f32(), Expr(1.f), {Override(1u)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -83,10 +85,8 @@
}
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
- GlobalConst("a", ty.f32(), Construct(ty.f32()),
- {Override(Source{{12, 34}}, 7u)});
- GlobalConst("b", ty.f32(), Construct(ty.f32()),
- {Override(Source{{56, 78}}, 7u)});
+ GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 7u)});
+ GlobalConst("b", ty.f32(), Expr(1.f), {Override(Source{{56, 78}}, 7u)});
EXPECT_FALSE(r()->Resolve());
@@ -95,8 +95,7 @@
}
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
- GlobalConst("a", ty.f32(), Construct(ty.f32()),
- {Override(Source{{12, 34}}, 65536u)});
+ GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 65536u)});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 1f3bcea..7a30a4d 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -547,13 +547,17 @@
binding_point = {bp.group->value, bp.binding->value};
}
+ auto* override =
+ ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
+ bool has_const_val = rhs && var->is_const && !override;
+
auto* global = builder_->create<sem::GlobalVariable>(
var, var_ty, storage_class, access,
- (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{},
+ has_const_val ? rhs->ConstantValue() : sem::Constant{},
binding_point);
- if (auto* override =
- ast::GetDecoration<ast::OverrideDecoration>(var->decorations)) {
+ if (override) {
+ global->SetIsOverridable();
if (override->has_value) {
global->SetConstantId(static_cast<uint16_t>(override->value));
}
@@ -1995,27 +1999,30 @@
return false;
}
- if (auto* ident = expr->As<ast::IdentifierExpression>()) {
- // We have an identifier of a module-scope constant.
- auto* var = variable_stack_.Get(ident->symbol);
- if (!var || !var->Declaration()->is_const) {
+ sem::Constant value;
+
+ if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
+ // We have an variable of a module-scope constant.
+ auto* decl = user->Variable()->Declaration();
+ if (!decl->is_const) {
AddError(kErrBadType, expr->source);
return false;
}
-
- auto* decl = var->Declaration();
// Capture the constant if an [[override]] attribute is present.
if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
ws[i].overridable_const = decl;
}
- expr = decl->constructor;
- if (!expr) {
+ if (decl->constructor) {
+ value = Sem(decl->constructor)->ConstantValue();
+ } else {
// No constructor means this value must be overriden by the user.
ws[i].value = 0;
continue;
}
- } else if (!expr->Is<ast::LiteralExpression>()) {
+ } else if (expr->Is<ast::LiteralExpression>()) {
+ value = Sem(expr)->ConstantValue();
+ } else {
AddError(
"workgroup_size argument must be either a literal or a "
"module-scope constant",
@@ -2023,20 +2030,19 @@
return false;
}
- auto val = expr_sem->ConstantValue();
- if (!val) {
+ if (!value) {
TINT_ICE(Resolver, diagnostics_)
<< "could not resolve constant workgroup_size constant value";
continue;
}
// Validate and set the default value for this dimension.
- if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
+ if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source);
return false;
}
- ws[i].value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
- : val.Elements()[0].u32;
+ ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
+ : value.Elements()[0].u32;
}
return true;
}
diff --git a/src/sem/variable.cc b/src/sem/variable.cc
index dd242be..8b5fdd9 100644
--- a/src/sem/variable.cc
+++ b/src/sem/variable.cc
@@ -61,8 +61,7 @@
Constant constant_value,
sem::BindingPoint binding_point)
: Base(declaration, type, storage_class, access, std::move(constant_value)),
- binding_point_(binding_point),
- is_pipeline_constant_(false) {}
+ binding_point_(binding_point) {}
GlobalVariable::~GlobalVariable() = default;
diff --git a/src/sem/variable.h b/src/sem/variable.h
index 5d4ac05..ffc89eb 100644
--- a/src/sem/variable.h
+++ b/src/sem/variable.h
@@ -129,22 +129,27 @@
/// @returns the resource binding point for the variable
sem::BindingPoint BindingPoint() const { return binding_point_; }
- /// @returns the pipeline constant ID associated with the variable
- uint16_t ConstantId() const { return constant_id_; }
-
/// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) {
constant_id_ = id;
- is_pipeline_constant_ = true;
+ is_overridable_ = true;
}
- /// @returns true if this variable is an overridable pipeline constant
- bool IsPipelineConstant() const { return is_pipeline_constant_; }
+ /// @returns the pipeline constant ID associated with the variable
+ uint16_t ConstantId() const { return constant_id_; }
+
+ /// @param is_overridable true if this is a pipeline overridable constant
+ void SetIsOverridable(bool is_overridable = true) {
+ is_overridable_ = is_overridable;
+ }
+
+ /// @returns true if this is pipeline overridable constant
+ bool IsOverridable() const { return is_overridable_; }
private:
const sem::BindingPoint binding_point_;
- bool is_pipeline_constant_ = false;
+ bool is_overridable_ = false;
uint16_t constant_id_ = 0;
};
diff --git a/src/writer/glsl/generator_impl.cc b/src/writer/glsl/generator_impl.cc
index 7825126..df24aaa 100644
--- a/src/writer/glsl/generator_impl.cc
+++ b/src/writer/glsl/generator_impl.cc
@@ -1835,7 +1835,7 @@
if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const);
- if (!global->IsPipelineConstant()) {
+ if (!global->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
@@ -2657,7 +2657,7 @@
auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>();
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id;
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0bc975f..0751248 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -2620,7 +2620,7 @@
if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const);
- if (!global->IsPipelineConstant()) {
+ if (!global->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
@@ -3407,7 +3407,7 @@
auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>();
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 77c87bd..94e011d 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -2651,7 +2651,7 @@
}
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
out << " [[function_constant(" << global->ConstantId() << ")]]";
} else if (var->constructor != nullptr) {
out << " = ";
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index e6a1365..697509d 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -524,7 +524,7 @@
// Make the constant specializable.
auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const);
- if (!sem_const->IsPipelineConstant()) {
+ if (!sem_const->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
@@ -1333,7 +1333,7 @@
// Generate the zero initializer if there are no values provided.
if (values.empty()) {
- if (global_var && global_var->IsPipelineConstant()) {
+ if (global_var && global_var->IsOverridable()) {
auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) {
return GenerateConstantIfNeeded(
@@ -1669,7 +1669,7 @@
ScalarConstant constant;
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
- if (global && global->IsPipelineConstant()) {
+ if (global && global->IsOverridable()) {
constant.is_spec_op = true;
constant.constant_id = global->ConstantId();
}