Import Tint changes from Dawn
Changes:
- 490d9889a79e39e1d1a91e01df68926f704a8089 tint: Simplify workgroup size resolving by Ben Clayton <bclayton@google.com>
- 1662f5578e46e7b47e3289b2764126c32a70e22d Fixup some grammar. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 490d9889a79e39e1d1a91e01df68926f704a8089
Change-Id: I6644d3475df834ebc08402993379beac2228b758
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/103400
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index 087e786..d32ef6f 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -148,9 +148,9 @@
entry_point.stage = PipelineStage::kCompute;
auto wgsize = sem->WorkgroupSize();
- if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
- !wgsize[2].overridable_const) {
- entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value};
+ if (wgsize[0].has_value() && wgsize[1].has_value() && wgsize[2].has_value()) {
+ entry_point.workgroup_size = {wgsize[0].value(), wgsize[1].value(),
+ wgsize[2].value()};
}
break;
}
@@ -849,19 +849,18 @@
auto* t = c->args[static_cast<size_t>(texture_index)];
auto* s = c->args[static_cast<size_t>(sampler_index)];
- GetOriginatingResources(
- std::array<const ast::Expression*, 2>{t, s},
- [&](std::array<const sem::GlobalVariable*, 2> globals) {
- auto texture_binding_point = globals[0]->BindingPoint();
- auto sampler_binding_point = globals[1]->BindingPoint();
+ GetOriginatingResources(std::array<const ast::Expression*, 2>{t, s},
+ [&](std::array<const sem::GlobalVariable*, 2> globals) {
+ auto texture_binding_point = globals[0]->BindingPoint();
+ auto sampler_binding_point = globals[1]->BindingPoint();
- for (auto* entry_point : entry_points) {
- const auto& ep_name =
- program_->Symbols().NameFor(entry_point->Declaration()->symbol);
- (*sampler_targets_)[ep_name].Add(
- {sampler_binding_point, texture_binding_point});
- }
- });
+ for (auto* entry_point : entry_points) {
+ const auto& ep_name = program_->Symbols().NameFor(
+ entry_point->Declaration()->symbol);
+ (*sampler_targets_)[ep_name].Add(
+ {sampler_binding_point, texture_binding_point});
+ }
+ });
}
}
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 22f2620..cf87296 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -468,7 +468,7 @@
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
// const x = 4u;
- // const x = 8u;
+ // const y = 8u;
// @compute @workgroup_size(x, y, 16u)
// fn main() {}
auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
@@ -489,10 +489,29 @@
ASSERT_NE(sem_x, nullptr);
ASSERT_NE(sem_y, nullptr);
+ EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{4u, 8u, 16u}));
+
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_x));
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_y));
}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Cast) {
+ // @compute @workgroup_size(i32(5))
+ // fn main() {}
+ auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 5_a)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_func = Sem().Get(func);
+
+ ASSERT_NE(sem_func, nullptr);
+ EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{5u, 1u, 1u}));
+}
+
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
// @compute @workgroup_size(1i, 2i, 3i)
// fn main() {}
@@ -651,9 +670,10 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size argument must be a constant or override expression of type "
+ "abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
@@ -696,9 +716,10 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size argument must be a constant or override expression of type "
+ "abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
@@ -759,8 +780,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ "12:34 error: workgroup_size argument must be a constant or override expression of "
+ "type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
@@ -774,8 +795,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ "12:34 error: workgroup_size argument must be a constant or override expression of "
+ "type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
@@ -789,8 +810,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ "12:34 error: workgroup_size argument must be a constant or override expression of "
+ "type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
@@ -804,8 +825,8 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size argument must be either a literal, constant, or "
- "overridable of type abstract-integer, i32 or u32");
+ "12:34 error: workgroup_size argument must be a constant or override expression of "
+ "type abstract-integer, i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 911b66c..c1920d8 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1050,8 +1050,7 @@
// Set work-group size defaults.
sem::WorkgroupSize ws;
for (size_t i = 0; i < 3; i++) {
- ws[i].value = 1;
- ws[i].overridable_const = nullptr;
+ ws[i] = 1;
}
auto* attr = ast::GetAttribute<ast::WorkgroupAttribute>(func->attributes);
@@ -1064,7 +1063,7 @@
utils::Vector<const sem::Type*, 3> arg_tys;
constexpr const char* kErrBadExpr =
- "workgroup_size argument must be either a literal, constant, or overridable of type "
+ "workgroup_size argument must be a constant or override expression of type "
"abstract-integer, i32 or u32";
for (size_t i = 0; i < 3; i++) {
@@ -1084,6 +1083,12 @@
return false;
}
+ if (expr->Stage() != sem::EvaluationStage::kConstant &&
+ expr->Stage() != sem::EvaluationStage::kOverride) {
+ AddError(kErrBadExpr, value->source);
+ return false;
+ }
+
args.Push(expr);
arg_tys.Push(ty);
}
@@ -1105,47 +1110,15 @@
if (!materialized) {
return false;
}
-
- const sem::Constant* value = nullptr;
-
- if (auto* user = args[i]->As<sem::VariableUser>()) {
- // We have an variable of a module-scope constant.
- auto* decl = user->Variable()->Declaration();
- if (!decl->IsAnyOf<ast::Const, ast::Override>()) {
- AddError(kErrBadExpr, values[i]->source);
+ if (auto* value = materialized->ConstantValue()) {
+ if (value->As<AInt>() < 1) {
+ AddError("workgroup_size argument must be at least 1", values[i]->source);
return false;
}
- // Capture the constant if it is pipeline-overridable.
- if (decl->Is<ast::Override>()) {
- ws[i].overridable_const = decl;
- }
-
- if (decl->constructor) {
- value = sem_.Get(decl->constructor)->ConstantValue();
- } else {
- // No constructor means this value must be overriden by the user.
- ws[i].value = 0;
- continue;
- }
- } else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
- value = materialized->ConstantValue();
+ ws[i] = value->As<uint32_t>();
} else {
- AddError(kErrBadExpr, values[i]->source);
- return false;
+ ws[i] = std::nullopt;
}
-
- if (!value) {
- TINT_ICE(Resolver, diagnostics_)
- << "could not resolve constant workgroup_size constant value";
- continue;
- }
- // validator_.Validate and set the default value for this dimension.
- if (value->As<AInt>() < 1) {
- AddError("workgroup_size argument must be at least 1", values[i]->source);
- return false;
- }
-
- ws[i].value = value->As<uint32_t>();
}
current_function_->SetWorkgroupSize(std::move(ws));
@@ -2280,7 +2253,7 @@
// Note: The spec is currently vague around the rules here. See
// https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when resolved.
std::string desc = "var '" + builder_->Symbols().NameFor(symbol) + "' ";
- AddError(desc + "cannot not be referenced at module-scope", expr->source);
+ AddError(desc + "cannot be referenced at module-scope", expr->source);
AddNote(desc + "declared here", variable->Declaration()->source);
return nullptr;
}
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 1a3c623..79cc139 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -993,12 +993,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 1u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 1u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], 1u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
}
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
@@ -1015,12 +1012,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], 2u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
}
TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) {
@@ -1043,12 +1037,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], 16u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], 2u);
}
TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) {
@@ -1071,12 +1062,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 4u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], 4u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
}
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
@@ -1085,9 +1073,9 @@
// @id(2) override depth = 2i;
// @compute @workgroup_size(width, height, depth)
// fn main() {}
- auto* width = Override("width", ty.i32(), Expr(16_i), Id(0_a));
- auto* height = Override("height", ty.i32(), Expr(8_i), Id(1_a));
- auto* depth = Override("depth", ty.i32(), Expr(2_i), Id(2_a));
+ Override("width", ty.i32(), Expr(16_i), Id(0_a));
+ Override("height", ty.i32(), Expr(8_i), Id(1_a));
+ Override("depth", ty.i32(), Expr(2_i), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
@@ -1099,12 +1087,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
}
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
@@ -1113,9 +1098,9 @@
// @id(2) override depth : i32;
// @compute @workgroup_size(width, height, depth)
// fn main() {}
- auto* width = Override("width", ty.i32(), Id(0_a));
- auto* height = Override("height", ty.i32(), Id(1_a));
- auto* depth = Override("depth", ty.i32(), Id(2_a));
+ Override("width", ty.i32(), Id(0_a));
+ Override("height", ty.i32(), Id(1_a));
+ Override("depth", ty.i32(), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
@@ -1127,12 +1112,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 0u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 0u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 0u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
}
TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
@@ -1140,7 +1122,7 @@
// const depth = 3i;
// @compute @workgroup_size(8, height, depth)
// fn main() {}
- auto* height = Override("height", ty.i32(), Expr(2_i), Id(0_a));
+ Override("height", ty.i32(), Expr(2_i), Id(0_a));
GlobalConst("depth", ty.i32(), Expr(3_i));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
@@ -1153,12 +1135,9 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u);
- EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
- EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
- EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
+ EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
+ EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
}
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index c5b7220..a5f7f37 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -342,7 +342,7 @@
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- R"(12:34 error: var 'size' cannot not be referenced at module-scope
+ R"(12:34 error: var 'size' cannot be referenced at module-scope
note: var 'size' declared here)");
}
diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc
index 9843d2f..ea67071 100644
--- a/src/tint/resolver/variable_validation_test.cc
+++ b/src/tint/resolver/variable_validation_test.cc
@@ -65,7 +65,7 @@
GlobalVar("b", ty.i32(), ast::StorageClass::kPrivate, Expr(Source{{56, 78}}, "a"));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(56:78 error: var 'a' cannot not be referenced at module-scope
+ EXPECT_EQ(r()->error(), R"(56:78 error: var 'a' cannot be referenced at module-scope
12:34 note: var 'a' declared here)");
}
diff --git a/src/tint/sem/function.cc b/src/tint/sem/function.cc
index 6562526..fc7809b 100644
--- a/src/tint/sem/function.cc
+++ b/src/tint/sem/function.cc
@@ -44,7 +44,7 @@
utils::VectorRef<Parameter*> parameters)
: Base(return_type, SetOwner(std::move(parameters), this), EvaluationStage::kRuntime),
declaration_(declaration),
- workgroup_size_{WorkgroupDimension{1}, WorkgroupDimension{1}, WorkgroupDimension{1}},
+ workgroup_size_{1, 1, 1},
return_location_(return_location) {}
Function::~Function() = default;
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index 3f7256a..50d853c 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -39,18 +39,10 @@
namespace tint::sem {
-/// WorkgroupDimension describes the size of a single dimension of an entry
-/// point's workgroup size.
-struct WorkgroupDimension {
- /// The size of this dimension.
- uint32_t value;
- /// A pipeline-overridable constant that overrides the size, or nullptr if
- /// this dimension is not overridable.
- const ast::Variable* overridable_const = nullptr;
-};
-
/// WorkgroupSize is a three-dimensional array of WorkgroupDimensions.
-using WorkgroupSize = std::array<WorkgroupDimension, 3>;
+/// Each dimension is a std::optional as a workgroup size can be a constant or override expression.
+/// Override expressions are not known at compilation time, so these will be std::nullopt.
+using WorkgroupSize = std::array<std::optional<uint32_t>, 3>;
/// Function holds the semantic information for function nodes.
class Function final : public Castable<Function, CallTarget> {
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index e332495..21840bc 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -102,7 +102,6 @@
namespace {
const char kTempNamePrefix[] = "tint_tmp";
-const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
@@ -1886,8 +1885,9 @@
[&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override*) {
// Override is removed with SubstituteOverride
- TINT_ICE(Writer, diagnostics_)
- << "Override should have been removed by the substitute_override transform.";
+ diagnostics_.add_error(diag::System::Writer,
+ "override expressions should have been removed with the "
+ "SubstituteOverride transform");
return false;
},
[&](const ast::Const*) {
@@ -2104,16 +2104,14 @@
}
out << "local_size_" << (i == 0 ? "x" : i == 1 ? "y" : "z") << " = ";
- if (wgsize[i].overridable_const) {
- auto* global = builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
- if (!global->Declaration()->Is<ast::Override>()) {
- TINT_ICE(Writer, builder_.Diagnostics())
- << "expected a pipeline-overridable constant";
- }
- out << kSpecConstantPrefix << global->OverrideId().value;
- } else {
- out << std::to_string(wgsize[i].value);
+ if (!wgsize[i].has_value()) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "override expressions should have been removed with the SubstituteOverride "
+ "transform");
+ return false;
}
+ out << std::to_string(wgsize[i].value());
}
out << ") in;";
}
diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc
index eb0a9a8..6afac16 100644
--- a/src/tint/writer/glsl/generator_impl_function_test.cc
+++ b/src/tint/writer/glsl/generator_impl_function_test.cc
@@ -783,6 +783,25 @@
)");
}
+TEST_F(GlslGeneratorImplTest_Function,
+ Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
+ Func("main", utils::Empty, ty.void_(), {},
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth"),
+ });
+
+ GeneratorImpl& gen = Build();
+
+ EXPECT_FALSE(gen.Generate()) << gen.error();
+ EXPECT_EQ(
+ gen.error(),
+ R"(error: override expressions should have been removed with the SubstituteOverride transform)");
+}
+
TEST_F(GlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func", utils::Vector{Param("a", ty.array<f32, 5>())}, ty.void_(),
utils::Vector{
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 3186955..516acdd 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -81,7 +81,6 @@
namespace {
const char kTempNamePrefix[] = "tint_tmp";
-const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
const char* image_format_to_rwtexture_type(ast::TexelFormat image_format) {
switch (image_format) {
@@ -2842,8 +2841,9 @@
},
[&](const ast::Override*) {
// Override is removed with SubstituteOverride
- TINT_ICE(Writer, diagnostics_)
- << "Override should have been removed by the substitute_override transform.";
+ diagnostics_.add_error(diag::System::Writer,
+ "override expressions should have been removed with the "
+ "SubstituteOverride transform");
return false;
},
[&](const ast::Const*) {
@@ -3044,18 +3044,14 @@
if (i > 0) {
out << ", ";
}
-
- if (wgsize[i].overridable_const) {
- auto* global =
- builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
- if (!global->Declaration()->Is<ast::Override>()) {
- TINT_ICE(Writer, diagnostics_)
- << "expected a pipeline-overridable constant";
- }
- out << kSpecConstantPrefix << global->OverrideId().value;
- } else {
- out << std::to_string(wgsize[i].value);
+ if (!wgsize[i].has_value()) {
+ diagnostics_.add_error(
+ diag::System::Writer,
+ "override expressions should have been removed with the SubstituteOverride "
+ "transform");
+ return false;
}
+ out << std::to_string(wgsize[i].value());
}
out << ")]" << std::endl;
}
diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc
index 5c167fa..322b560 100644
--- a/src/tint/writer/hlsl/generator_impl_function_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_function_test.cc
@@ -712,6 +712,25 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Function,
+ Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth"),
+ });
+
+ GeneratorImpl& gen = Build();
+
+ EXPECT_FALSE(gen.Generate()) << gen.error();
+ EXPECT_EQ(
+ gen.error(),
+ R"(error: override expressions should have been removed with the SubstituteOverride transform)");
+}
+
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func",
utils::Vector{
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 5ff5751..d1bab39 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -273,8 +273,9 @@
},
[&](const ast::Override*) {
// Override is removed with SubstituteOverride
- TINT_ICE(Writer, diagnostics_)
- << "Override should have been removed by the substitute_override transform.";
+ diagnostics_.add_error(diag::System::Writer,
+ "override expressions should have been removed with the "
+ "SubstituteOverride transform.");
return false;
},
[&](const ast::Function* func) {
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index f51d808..05f3716 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -506,13 +506,17 @@
} else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
auto& wgsize = func_sem->WorkgroupSize();
- // SubstituteOverride replaced all overrides with constants.
- uint32_t x = wgsize[0].value;
- uint32_t y = wgsize[1].value;
- uint32_t z = wgsize[2].value;
- push_execution_mode(spv::Op::OpExecutionMode,
- {Operand(id), U32Operand(SpvExecutionModeLocalSize), Operand(x),
- Operand(y), Operand(z)});
+ // Check if the workgroup_size uses pipeline-overridable constants.
+ if (!wgsize[0].has_value() || !wgsize[1].has_value() || !wgsize[2].has_value()) {
+ error_ =
+ "override expressions should have been removed with the SubstituteOverride "
+ "transform";
+ return false;
+ }
+ push_execution_mode(
+ spv::Op::OpExecutionMode,
+ {Operand(id), U32Operand(SpvExecutionModeLocalSize), //
+ Operand(wgsize[0].value()), Operand(wgsize[1].value()), Operand(wgsize[2].value())});
}
for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) {
diff --git a/src/tint/writer/spirv/builder_function_attribute_test.cc b/src/tint/writer/spirv/builder_function_attribute_test.cc
index 9e6c338..60bf062 100644
--- a/src/tint/writer/spirv/builder_function_attribute_test.cc
+++ b/src/tint/writer/spirv/builder_function_attribute_test.cc
@@ -149,6 +149,41 @@
)");
}
+TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) {
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
+ auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ WorkgroupSize("width", "height", "depth"),
+ Stage(ast::PipelineStage::kCompute),
+ });
+
+ spirv::Builder& b = Build();
+
+ EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error();
+ EXPECT_EQ(
+ b.error(),
+ R"(override expressions should have been removed with the SubstituteOverride transform)");
+}
+
+TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) {
+ Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3_i));
+ auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ WorkgroupSize(4_i, "height", "depth"),
+ Stage(ast::PipelineStage::kCompute),
+ });
+
+ spirv::Builder& b = Build();
+
+ EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error();
+ EXPECT_EQ(
+ b.error(),
+ R"(override expressions should have been removed with the SubstituteOverride transform)");
+}
+
TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) {
auto* func1 = Func("main1", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{