[dawn] Fix validation for override not in module.
According to the WebGPU spec, when setting constants into the pipeline:
`The pipeline-overridable constant is not required to be statically
used by entryPoint.` This is currently treated as an error in Dawn and
the pipeline fails.
Remove the validation for the Id to be present in the module.
Bug: 338624452
Change-Id: I4bdc3e6cddf032695541787d2de9e3dbaa1d5410
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/231835
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 6f51472..16ae996 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -643,9 +643,8 @@
return invalid; \
})()
+ const auto& name2Id = inspector->GetNamedOverrideIds();
if (!entryPoint.overrides.empty()) {
- const auto& name2Id = inspector->GetNamedOverrideIds();
-
for (auto& c : entryPoint.overrides) {
auto id = name2Id.at(c.name);
EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type),
@@ -667,6 +666,20 @@
}
}
+ // Add overrides which are not used by the entry point into the list so we
+ // can validate set constants in the pipeline.
+ for (auto& o : inspector->Overrides()) {
+ std::string identifier = o.is_id_specified ? std::to_string(o.id.value) : o.name;
+ if (metadata->overrides.count(identifier) != 0) {
+ continue;
+ }
+
+ auto id = name2Id.at(o.name);
+ EntryPointMetadata::Override override = {id, FromTintOverrideType(o.type), o.is_initialized,
+ /* isUsed */ false};
+ metadata->overrides[identifier] = override;
+ }
+
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
if (metadata->stage == SingleShaderStage::Compute) {
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index e852e39..c9ad6b5 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -265,6 +265,9 @@
// Then it is required for the pipeline stage to have a constant record to initialize a
// value
bool isInitialized;
+
+ // Set to true if the override is used in the entry point
+ bool isUsed = true;
};
using OverridesMap = absl::flat_hash_map<std::string, Override>;
diff --git a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp
index a178674..86399a6 100644
--- a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp
@@ -60,6 +60,14 @@
override c12: f16 = 0.0h; // default override
@id(1000) override c13: f16 = 4.0h; // default
+override u01: f16 = 0.0h; // default override
+@id(2000) override u02: f16 = 0.0h; // default override
+
+override u10: f32;
+override u11: i32;
+override u12: u32;
+override u13: f16;
+
@compute @workgroup_size(1) fn main() {
// make sure the overridable constants are not optimized out
_ = u32(c0);
@@ -148,6 +156,16 @@
TestCreatePipeline(constants);
}
{
+ // Valid: in module but unused by entry point.
+ std::vector<wgpu::ConstantEntry> constants{{nullptr, "u01", 0}};
+ TestCreatePipeline(constants);
+ }
+ {
+ // Valid: in module but unused by entry point.
+ std::vector<wgpu::ConstantEntry> constants{{nullptr, "2000", 0}};
+ TestCreatePipeline(constants);
+ }
+ {
// Error: set the same constant twice
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c0", 0},
@@ -161,7 +179,7 @@
TestCreatePipeline(constants);
}
{
- // Error: c10 already has a constant numeric id specified
+ // Error: c13 already has a constant numeric id specified
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c13", 0}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
@@ -383,5 +401,86 @@
}
}
+// Test that values that are not representable by WGSL type i32/u32/f16/f32 for unused overrides are
+// errors
+TEST_F(ComputePipelineOverridableConstantsValidationTest, UnusedOutofRangeValue) {
+ SetUpShadersWithDefaultValueConstants();
+ {
+ // Error: 1.79769e+308 cannot be represented by f32
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u10", std::numeric_limits<double>::max()}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Valid: max f32 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u10", std::numeric_limits<float>::max()}};
+ TestCreatePipeline(constants);
+ }
+ {
+ // Error: one ULP higher than max f32 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u10",
+ std::nextafter<double>(std::numeric_limits<float>::max(),
+ std::numeric_limits<double>::max())}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Valid: lowest f32 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u10", std::numeric_limits<float>::lowest()}};
+ TestCreatePipeline(constants);
+ }
+ {
+ // Error: one ULP lower than lowest f32 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u10",
+ std::nextafter<double>(std::numeric_limits<float>::lowest(),
+ std::numeric_limits<double>::lowest())}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Error: i32 out of range
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u11", static_cast<double>(std::numeric_limits<int32_t>::max()) + 1.0}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Error: i32 out of range (negative)
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u11", static_cast<double>(std::numeric_limits<int32_t>::lowest()) - 1.0}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Error: u32 out of range
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u12", static_cast<double>(std::numeric_limits<uint32_t>::max()) + 1.0}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Valid: max f16 representable value
+ std::vector<wgpu::ConstantEntry> constants{{nullptr, "c11", 65504.0}};
+ TestCreatePipeline(constants);
+ }
+ {
+ // Error: one ULP higher than max f16 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u13", std::nextafter<double>(65504.0, std::numeric_limits<double>::max())}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+ {
+ // Valid: lowest f16 representable value
+ std::vector<wgpu::ConstantEntry> constants{{nullptr, "u13", -65504.0}};
+ TestCreatePipeline(constants);
+ }
+ {
+ // Error: one ULP lower than lowest f16 representable value
+ std::vector<wgpu::ConstantEntry> constants{
+ {nullptr, "u13",
+ std::nextafter<double>(-65504.0, std::numeric_limits<double>::lowest())}};
+ ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
+ }
+}
+
} // anonymous namespace
} // namespace dawn
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 022e0ee..a0dde13 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -220,6 +220,35 @@
return result;
}
+inspector::Override MkOverride(const sem::GlobalVariable* global, OverrideId id) {
+ Override override;
+ override.name = global->Declaration()->name->symbol.Name();
+ override.id = id;
+
+ auto* type = global->Type();
+ TINT_ASSERT(type->Is<core::type::Scalar>());
+ if (type->IsBoolScalarOrVector()) {
+ override.type = Override::Type::kBool;
+ } else if (type->IsFloatScalar()) {
+ if (type->Is<core::type::F16>()) {
+ override.type = Override::Type::kFloat16;
+ } else {
+ override.type = Override::Type::kFloat32;
+ }
+ } else if (type->IsSignedIntegerScalar()) {
+ override.type = Override::Type::kInt32;
+ } else if (type->IsUnsignedIntegerScalar()) {
+ override.type = Override::Type::kUint32;
+ } else {
+ TINT_UNREACHABLE();
+ }
+
+ override.is_initialized = global->Declaration()->initializer;
+ override.is_id_specified =
+ ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
+ return override;
+}
+
} // namespace
Inspector::Inspector(const Program& program) : program_(program) {}
@@ -303,38 +332,9 @@
}
for (auto* var : sem->TransitivelyReferencedGlobals()) {
- auto* decl = var->Declaration();
-
- auto name = decl->name->symbol.Name();
-
auto* global = var->As<sem::GlobalVariable>();
if (auto override_id = global->Attributes().override_id) {
- Override override;
- override.name = name;
- override.id = override_id.value();
- auto* type = var->Type();
- TINT_ASSERT(type->Is<core::type::Scalar>());
- if (type->IsBoolScalarOrVector()) {
- override.type = Override::Type::kBool;
- } else if (type->IsFloatScalar()) {
- if (type->Is<core::type::F16>()) {
- override.type = Override::Type::kFloat16;
- } else {
- override.type = Override::Type::kFloat32;
- }
- } else if (type->IsSignedIntegerScalar()) {
- override.type = Override::Type::kInt32;
- } else if (type->IsUnsignedIntegerScalar()) {
- override.type = Override::Type::kUint32;
- } else {
- TINT_UNREACHABLE();
- }
-
- override.is_initialized = global->Declaration()->initializer;
- override.is_id_specified =
- ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
-
- entry_point.overrides.push_back(override);
+ entry_point.overrides.push_back(MkOverride(global, override_id.value()));
}
}
@@ -405,7 +405,7 @@
continue;
}
- // If there are conflicting defintions for an override id, that is invalid
+ // If there are conflicting definitions for an override id, that is invalid
// WGSL, so the resolver should catch it. Thus here the inspector just
// assumes all definitions of the override id are the same, so only needs
// to find the first reference to override id.
@@ -1056,4 +1056,18 @@
return false;
}
+std::vector<Override> Inspector::Overrides() {
+ std::vector<Override> results;
+
+ for (auto* var : program_.AST().GlobalVariables()) {
+ auto* global = program_.Sem().Get<sem::GlobalVariable>(var);
+ if (!global || !global->Declaration()->Is<ast::Override>()) {
+ continue;
+ }
+
+ results.push_back(MkOverride(global, global->Attributes().override_id.value()));
+ }
+ return results;
+}
+
} // namespace tint::inspector
diff --git a/src/tint/lang/wgsl/inspector/inspector.h b/src/tint/lang/wgsl/inspector/inspector.h
index 928dae9..0f04bd8 100644
--- a/src/tint/lang/wgsl/inspector/inspector.h
+++ b/src/tint/lang/wgsl/inspector/inspector.h
@@ -80,6 +80,9 @@
/// @returns map of module-constant name to pipeline constant ID
std::map<std::string, OverrideId> GetNamedOverrideIds();
+ /// @returns vector of all overrides
+ std::vector<Override> Overrides();
+
/// @param entry_point name of the entry point to get information about.
/// @returns vector of all of the resource bindings.
std::vector<ResourceBinding> GetResourceBindings(const std::string& entry_point);
diff --git a/src/tint/lang/wgsl/inspector/inspector_test.cc b/src/tint/lang/wgsl/inspector/inspector_test.cc
index 80c4163..eb65ed0 100644
--- a/src/tint/lang/wgsl/inspector/inspector_test.cc
+++ b/src/tint/lang/wgsl/inspector/inspector_test.cc
@@ -63,6 +63,7 @@
// returned Inspector from ::Initialize can then be used to test expectations.
class InspectorGetEntryPointTest : public InspectorBuilder, public testing::Test {};
+class InspectorOverridesTest : public InspectorBuilder, public testing::Test {};
typedef std::tuple<inspector::ComponentType, inspector::CompositionType>
InspectorGetEntryPointComponentAndCompositionTestParams;
@@ -1758,6 +1759,47 @@
core::InterpolationType::kFlat, core::InterpolationSampling::kEither,
InterpolationType::kFlat, InterpolationSampling::kEither}));
+TEST_F(InspectorOverridesTest, NoOverrides) {
+ MakeCallerBodyFunction("ep_func", Empty,
+ Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1_i),
+ });
+
+ Inspector& inspector = Build();
+
+ auto result = inspector.Overrides();
+ EXPECT_TRUE(result.empty());
+}
+
+TEST_F(InspectorOverridesTest, Multiple) {
+ Override("foo", ty.f32(), Id(1_a));
+ Override("bar", ty.f32(), Id(2_a));
+ MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), tint::Empty);
+ MakeCallerBodyFunction("ep_func", Vector{std::string("callee_func")},
+ Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1_i),
+ });
+
+ Inspector& inspector = Build();
+
+ auto result = inspector.Overrides();
+ ASSERT_EQ(2u, result.size());
+
+ auto& ep = result[0];
+ EXPECT_EQ(ep.name, "foo");
+ EXPECT_EQ(ep.id.value, 1);
+ EXPECT_FALSE(ep.is_initialized);
+ EXPECT_TRUE(ep.is_id_specified);
+
+ ep = result[1];
+ EXPECT_EQ(ep.name, "bar");
+ EXPECT_EQ(ep.id.value, 2);
+ EXPECT_FALSE(ep.is_initialized);
+ EXPECT_TRUE(ep.is_id_specified);
+}
+
TEST_F(InspectorGetOverrideDefaultValuesTest, Bool) {
GlobalConst("C", Expr(true));
Override("a", ty.bool_(), Id(1_a));