Implement entryPoint defaulting
Spec: https://github.com/gpuweb/gpuweb/pull/4387
Bug: dawn:2254
Change-Id: I51c67253b095f5daf59dac378624b6bc38fe524d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/161901
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Fr <beaufort.francois@gmail.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/dawn.json b/dawn.json
index 7a80ae8..fc76c12 100644
--- a/dawn.json
+++ b/dawn.json
@@ -2228,7 +2228,7 @@
"extensible": "in",
"members": [
{"name": "module", "type": "shader module"},
- {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+ {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
{"name": "constant count", "type": "size_t", "default": 0},
{"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"}
]
@@ -2832,7 +2832,7 @@
"extensible": "in",
"members": [
{"name": "module", "type": "shader module"},
- {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+ {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
{"name": "constant count", "type": "size_t", "default": 0},
{"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"},
{"name": "buffer count", "type": "size_t", "default": 0},
@@ -2912,7 +2912,7 @@
"extensible": "in",
"members": [
{"name": "module", "type": "shader module"},
- {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+ {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
{"name": "constant count", "type": "size_t", "default": 0},
{"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"},
{"name": "target count", "type": "size_t"},
diff --git a/src/dawn/native/ComputePipeline.cpp b/src/dawn/native/ComputePipeline.cpp
index cbb5cb6..b3b99d9 100644
--- a/src/dawn/native/ComputePipeline.cpp
+++ b/src/dawn/native/ComputePipeline.cpp
@@ -43,12 +43,14 @@
DAWN_TRY(device->ValidateObject(descriptor->layout));
}
- DAWN_TRY_CONTEXT(ValidateProgrammableStage(
- device, descriptor->compute.module, descriptor->compute.entryPoint,
- descriptor->compute.constantCount, descriptor->compute.constants,
- descriptor->layout, SingleShaderStage::Compute),
- "validating compute stage (%s, entryPoint: %s).", descriptor->compute.module,
- descriptor->compute.entryPoint);
+ ShaderModuleEntryPoint entryPoint;
+ DAWN_TRY_ASSIGN_CONTEXT(entryPoint,
+ ValidateProgrammableStage(
+ device, descriptor->compute.module, descriptor->compute.entryPoint,
+ descriptor->compute.constantCount, descriptor->compute.constants,
+ descriptor->layout, SingleShaderStage::Compute),
+ "validating compute stage (%s, entryPoint: %s).",
+ descriptor->compute.module, descriptor->compute.entryPoint);
return {};
}
diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp
index 347c738..ff1f5b9 100644
--- a/src/dawn/native/Pipeline.cpp
+++ b/src/dawn/native/Pipeline.cpp
@@ -47,33 +47,46 @@
} // namespace
namespace dawn::native {
-MaybeError ValidateProgrammableStage(DeviceBase* device,
- const ShaderModuleBase* module,
- const std::string& entryPoint,
- uint32_t constantCount,
- const ConstantEntry* constants,
- const PipelineLayoutBase* layout,
- SingleShaderStage stage) {
+ResultOrError<ShaderModuleEntryPoint> ValidateProgrammableStage(DeviceBase* device,
+ const ShaderModuleBase* module,
+ const char* entryPointName,
+ uint32_t constantCount,
+ const ConstantEntry* constants,
+ const PipelineLayoutBase* layout,
+ SingleShaderStage stage) {
DAWN_TRY(device->ValidateObject(module));
- DAWN_INVALID_IF(!module->HasEntryPoint(entryPoint),
- "Entry point \"%s\" doesn't exist in the shader module %s.", entryPoint,
- module);
+ if (entryPointName) {
+ DAWN_INVALID_IF(!module->HasEntryPoint(entryPointName),
+ "Entry point \"%s\" doesn't exist in the shader module %s.", entryPointName,
+ module);
+ } else {
+ size_t entryPointCount = module->GetEntryPointCount(stage);
+ if (entryPointCount == 0) {
+ return DAWN_VALIDATION_ERROR(
+ "Compatible entry point for stage (%s) doesn't exist in the shader module %s.",
+ stage, module);
+ } else if (entryPointCount > 1) {
+ return DAWN_VALIDATION_ERROR(
+ "Multiple entry points for stage (%s) exist in the shader module %s.", stage,
+ module);
+ }
+ }
- const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint);
+ ShaderModuleEntryPoint entryPoint = module->ReifyEntryPointName(entryPointName, stage);
+ const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint.name);
if (!metadata.infringedLimitErrors.empty()) {
std::ostringstream limitList;
for (const std::string& limit : metadata.infringedLimitErrors) {
limitList << " - " << limit << "\n";
}
- return DAWN_VALIDATION_ERROR("Entry point \"%s\" infringes limits:\n%s", entryPoint,
- limitList.str());
+ return DAWN_VALIDATION_ERROR("%s infringes limits:\n%s", &entryPoint, limitList.str());
}
DAWN_INVALID_IF(metadata.stage != stage,
"The stage (%s) of the entry point \"%s\" isn't the expected one (%s).",
- metadata.stage, entryPoint, stage);
+ metadata.stage, entryPoint.name, stage);
if (layout != nullptr) {
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
@@ -163,7 +176,7 @@
uninitializedConstantsArray);
}
- return {};
+ return entryPoint;
}
WGPUCreatePipelineAsyncStatus CreatePipelineAsyncStatusFromErrorType(InternalErrorType error) {
diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h
index 7bac9f7..c2234d2 100644
--- a/src/dawn/native/Pipeline.h
+++ b/src/dawn/native/Pipeline.h
@@ -45,13 +45,13 @@
namespace dawn::native {
-MaybeError ValidateProgrammableStage(DeviceBase* device,
- const ShaderModuleBase* module,
- const std::string& entryPoint,
- uint32_t constantCount,
- const ConstantEntry* constants,
- const PipelineLayoutBase* layout,
- SingleShaderStage stage);
+ResultOrError<ShaderModuleEntryPoint> ValidateProgrammableStage(DeviceBase* device,
+ const ShaderModuleBase* module,
+ const char* entryPointName,
+ uint32_t constantCount,
+ const ConstantEntry* constants,
+ const PipelineLayoutBase* layout,
+ SingleShaderStage stage);
WGPUCreatePipelineAsyncStatus CreatePipelineAsyncStatusFromErrorType(InternalErrorType error);
diff --git a/src/dawn/native/PipelineLayout.cpp b/src/dawn/native/PipelineLayout.cpp
index 4a4180c..165a070 100644
--- a/src/dawn/native/PipelineLayout.cpp
+++ b/src/dawn/native/PipelineLayout.cpp
@@ -92,6 +92,17 @@
return {};
}
+StageAndDescriptor::StageAndDescriptor(SingleShaderStage shaderStage,
+ ShaderModuleBase* module,
+ const char* entryPoint,
+ size_t constantCount,
+ ConstantEntry const* constants)
+ : shaderStage(shaderStage),
+ module(module),
+ entryPoint(module->ReifyEntryPointName(entryPoint, shaderStage).name),
+ constantCount(constantCount),
+ constants(constants) {}
+
// PipelineLayoutBase
PipelineLayoutBase::PipelineLayoutBase(DeviceBase* device,
diff --git a/src/dawn/native/PipelineLayout.h b/src/dawn/native/PipelineLayout.h
index 756f8bb..8f7dcbd 100644
--- a/src/dawn/native/PipelineLayout.h
+++ b/src/dawn/native/PipelineLayout.h
@@ -56,6 +56,12 @@
using BindGroupLayoutMask = ityp::bitset<BindGroupIndex, kMaxBindGroups>;
struct StageAndDescriptor {
+ StageAndDescriptor(SingleShaderStage shaderStage,
+ ShaderModuleBase* module,
+ const char* entryPoint,
+ size_t constantCount,
+ ConstantEntry const* constants);
+
SingleShaderStage shaderStage;
ShaderModuleBase* module;
std::string entryPoint;
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index 8286dd6..fdf2f1c 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -173,10 +173,11 @@
return {};
}
-MaybeError ValidateVertexState(DeviceBase* device,
- const VertexState* descriptor,
- const PipelineLayoutBase* layout,
- wgpu::PrimitiveTopology primitiveTopology) {
+ResultOrError<ShaderModuleEntryPoint> ValidateVertexState(
+ DeviceBase* device,
+ const VertexState* descriptor,
+ const PipelineLayoutBase* layout,
+ wgpu::PrimitiveTopology primitiveTopology) {
DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");
const CombinedLimits& limits = device->GetLimits();
@@ -185,13 +186,15 @@
"Vertex buffer count (%u) exceeds the maximum number of vertex buffers (%u).",
descriptor->bufferCount, limits.v1.maxVertexBuffers);
- DAWN_TRY_CONTEXT(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
- descriptor->constantCount, descriptor->constants,
- layout, SingleShaderStage::Vertex),
- "validating vertex stage (%s, entryPoint: %s).", descriptor->module,
- descriptor->entryPoint);
- const EntryPointMetadata& vertexMetadata =
- descriptor->module->GetEntryPoint(descriptor->entryPoint);
+ ShaderModuleEntryPoint entryPoint;
+ DAWN_TRY_ASSIGN_CONTEXT(
+ entryPoint,
+ ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+ descriptor->constantCount, descriptor->constants, layout,
+ SingleShaderStage::Vertex),
+ "validating vertex stage (%s, entryPoint: %s).", descriptor->module,
+ descriptor->entryPoint);
+ const EntryPointMetadata& vertexMetadata = descriptor->module->GetEntryPoint(entryPoint.name);
if (primitiveTopology == wgpu::PrimitiveTopology::PointList) {
DAWN_INVALID_IF(
vertexMetadata.totalInterStageShaderComponents + 1 >
@@ -239,12 +242,12 @@
VertexAttributeLocation firstMissing = ityp::Sub(
GetHighestBitIndexPlusOne(missingAttributes), VertexAttributeLocation(uint8_t(1)));
return DAWN_VALIDATION_ERROR(
- "Vertex attribute slot %u used in (%s, entryPoint: %s) is not present in the "
+ "Vertex attribute slot %u used in (%s, %s) is not present in the "
"VertexState.",
- uint8_t(firstMissing), descriptor->module, descriptor->entryPoint);
+ uint8_t(firstMissing), descriptor->module, &entryPoint);
}
- return {};
+ return entryPoint;
}
MaybeError ValidatePrimitiveState(const DeviceBase* device, const PrimitiveState* descriptor) {
@@ -543,35 +546,36 @@
return {};
}
-MaybeError ValidateFragmentState(DeviceBase* device,
- const FragmentState* descriptor,
- const PipelineLayoutBase* layout,
- const DepthStencilState* depthStencil,
- const MultisampleState& multisample) {
+ResultOrError<ShaderModuleEntryPoint> ValidateFragmentState(DeviceBase* device,
+ const FragmentState* descriptor,
+ const PipelineLayoutBase* layout,
+ const DepthStencilState* depthStencil,
+ const MultisampleState& multisample) {
DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");
- DAWN_TRY_CONTEXT(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
- descriptor->constantCount, descriptor->constants,
- layout, SingleShaderStage::Fragment),
- "validating fragment stage (%s, entryPoint: %s).", descriptor->module,
- descriptor->entryPoint);
+ ShaderModuleEntryPoint entryPoint;
+ DAWN_TRY_ASSIGN_CONTEXT(
+ entryPoint,
+ ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+ descriptor->constantCount, descriptor->constants, layout,
+ SingleShaderStage::Fragment),
+ "validating fragment stage (%s, entryPoint: %s).", descriptor->module,
+ descriptor->entryPoint);
- const EntryPointMetadata& fragmentMetadata =
- descriptor->module->GetEntryPoint(descriptor->entryPoint);
+ const EntryPointMetadata& fragmentMetadata = descriptor->module->GetEntryPoint(entryPoint.name);
if (fragmentMetadata.usesFragDepth) {
- DAWN_INVALID_IF(
- depthStencil == nullptr,
- "Depth stencil state is not present when fragment stage (%s, entryPoint: %s) is "
- "writing to frag_depth.",
- descriptor->module, descriptor->entryPoint);
+ DAWN_INVALID_IF(depthStencil == nullptr,
+ "Depth stencil state is not present when fragment stage (%s, %s) is "
+ "writing to frag_depth.",
+ descriptor->module, &entryPoint);
const Format* depthStencilFormat;
DAWN_TRY_ASSIGN(depthStencilFormat, device->GetInternalFormat(depthStencil->format));
DAWN_INVALID_IF(!depthStencilFormat->HasDepth(),
"Depth stencil state format (%s) has no depth aspect when fragment stage "
- "(%s, entryPoint: %s) is "
+ "(%s, %s) is "
"writing to frag_depth.",
- depthStencil->format, descriptor->module, descriptor->entryPoint);
+ depthStencil->format, descriptor->module, &entryPoint);
}
uint32_t maxColorAttachments = device->GetLimits().v1.maxColorAttachments;
@@ -639,9 +643,8 @@
if (device->IsCompatibilityMode()) {
DAWN_INVALID_IF(
fragmentMetadata.usesSampleMaskOutput,
- "sample_mask is not supported in compatibility mode in the fragment stage (%s, "
- "entryPoint: %s)",
- descriptor->module, descriptor->entryPoint);
+ "sample_mask is not supported in compatibility mode in the fragment stage (%s, %s)",
+ descriptor->module, &entryPoint);
// Check that all the color target states match.
ColorAttachmentIndex firstColorTargetIndex{};
@@ -659,16 +662,18 @@
}
}
- return {};
+ return entryPoint;
}
MaybeError ValidateInterStageMatching(DeviceBase* device,
const VertexState& vertexState,
- const FragmentState& fragmentState) {
+ const ShaderModuleEntryPoint& vertexEntryPoint,
+ const FragmentState& fragmentState,
+ const ShaderModuleEntryPoint& fragmentEntryPoint) {
const EntryPointMetadata& vertexMetadata =
- vertexState.module->GetEntryPoint(vertexState.entryPoint);
+ vertexState.module->GetEntryPoint(vertexEntryPoint.name);
const EntryPointMetadata& fragmentMetadata =
- fragmentState.module->GetEntryPoint(fragmentState.entryPoint);
+ fragmentState.module->GetEntryPoint(fragmentEntryPoint.name);
size_t maxInterStageShaderVariables = device->GetLimits().v1.maxInterStageShaderVariables;
DAWN_ASSERT(vertexMetadata.usedInterStageVariables.size() == maxInterStageShaderVariables);
@@ -744,9 +749,11 @@
DAWN_TRY(device->ValidateObject(descriptor->layout));
}
- DAWN_TRY_CONTEXT(ValidateVertexState(device, &descriptor->vertex, descriptor->layout,
- descriptor->primitive.topology),
- "validating vertex state.");
+ ShaderModuleEntryPoint vertexEntryPoint;
+ DAWN_TRY_ASSIGN_CONTEXT(vertexEntryPoint,
+ ValidateVertexState(device, &descriptor->vertex, descriptor->layout,
+ descriptor->primitive.topology),
+ "validating vertex state.");
DAWN_TRY_CONTEXT(ValidatePrimitiveState(device, &descriptor->primitive),
"validating primitive state.");
@@ -764,9 +771,12 @@
"alphaToCoverageEnabled is true when fragment state is not present.");
if (descriptor->fragment != nullptr) {
- DAWN_TRY_CONTEXT(ValidateFragmentState(device, descriptor->fragment, descriptor->layout,
- descriptor->depthStencil, descriptor->multisample),
- "validating fragment state.");
+ ShaderModuleEntryPoint fragmentEntryPoint;
+ DAWN_TRY_ASSIGN_CONTEXT(
+ fragmentEntryPoint,
+ ValidateFragmentState(device, descriptor->fragment, descriptor->layout,
+ descriptor->depthStencil, descriptor->multisample),
+ "validating fragment state.");
bool hasStorageAttachments =
descriptor->layout != nullptr && descriptor->layout->HasAnyStorageAttachments();
@@ -774,7 +784,8 @@
!hasStorageAttachments,
"No attachment was specified (color, depth-stencil or other).");
- DAWN_TRY(ValidateInterStageMatching(device, descriptor->vertex, *(descriptor->fragment)));
+ DAWN_TRY(ValidateInterStageMatching(device, descriptor->vertex, vertexEntryPoint,
+ *(descriptor->fragment), fragmentEntryPoint));
}
return {};
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index b119d8e..3e55c47 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -1263,6 +1263,19 @@
return mEntryPoints.count(entryPoint) > 0;
}
+ShaderModuleEntryPoint ShaderModuleBase::ReifyEntryPointName(const char* entryPointName,
+ SingleShaderStage stage) const {
+ ShaderModuleEntryPoint entryPoint;
+ if (entryPointName) {
+ entryPoint.defaulted = false;
+ entryPoint.name = entryPointName;
+ } else {
+ entryPoint.defaulted = true;
+ entryPoint.name = mDefaultEntryPointNames[stage];
+ }
+ return entryPoint;
+}
+
const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const {
DAWN_ASSERT(HasEntryPoint(entryPoint));
return *mEntryPoints.at(entryPoint);
@@ -1339,6 +1352,18 @@
DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
&mEntryPoints, &mEnabledWGSLExtensions));
+
+ for (auto stage : IterateStages(kAllStages)) {
+ mEntryPointCounts[stage] = 0;
+ }
+ for (auto& [name, metadata] : mEntryPoints) {
+ SingleShaderStage stage = metadata->stage;
+ if (mEntryPointCounts[stage] == 0) {
+ mDefaultEntryPointNames[stage] = name;
+ }
+ mEntryPointCounts[stage]++;
+ }
+
return {};
}
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 1df1170..86fc3ca 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -120,6 +120,11 @@
std::unique_ptr<TintSource> tintSource;
};
+struct ShaderModuleEntryPoint {
+ bool defaulted;
+ std::string name;
+};
+
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
@@ -298,6 +303,13 @@
// Return true iff the program has an entrypoint called `entryPoint`.
bool HasEntryPoint(const std::string& entryPoint) const;
+ // Return the number of entry points for a stage.
+ size_t GetEntryPointCount(SingleShaderStage stage) const { return mEntryPointCounts[stage]; }
+
+ // Return the entry point for a stage. If no entry point name, returns the default one.
+ ShaderModuleEntryPoint ReifyEntryPointName(const char* entryPointName,
+ SingleShaderStage stage) const;
+
// Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
// must be true.
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;
@@ -334,6 +346,8 @@
std::string mWgsl;
EntryPointMetadataTable mEntryPoints;
+ PerStage<std::string> mDefaultEntryPointNames;
+ PerStage<size_t> mEntryPointCounts;
WGSLExtensionSet mEnabledWGSLExtensions;
std::unique_ptr<tint::Program> mTintProgram;
std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive
diff --git a/src/dawn/native/webgpu_absl_format.cpp b/src/dawn/native/webgpu_absl_format.cpp
index a2dd19c..1abc5c7 100644
--- a/src/dawn/native/webgpu_absl_format.cpp
+++ b/src/dawn/native/webgpu_absl_format.cpp
@@ -165,6 +165,22 @@
return {true};
}
+absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert(
+ const ShaderModuleEntryPoint* value,
+ const absl::FormatConversionSpec& spec,
+ absl::FormatSink* s) {
+ if (value == nullptr) {
+ s->Append("[null]");
+ return {true};
+ }
+ s->Append(absl::StrFormat("[EntryPoint \"%s\"", value->name));
+ if (value->defaulted) {
+ s->Append(" (defaulted)");
+ }
+ s->Append("]");
+ return {true};
+}
+
//
// Objects
//
diff --git a/src/dawn/native/webgpu_absl_format.h b/src/dawn/native/webgpu_absl_format.h
index d1dcef6..e8f9da8 100644
--- a/src/dawn/native/webgpu_absl_format.h
+++ b/src/dawn/native/webgpu_absl_format.h
@@ -94,6 +94,12 @@
const absl::FormatConversionSpec& spec,
absl::FormatSink* s);
+struct ShaderModuleEntryPoint;
+absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert(
+ const ShaderModuleEntryPoint* value,
+ const absl::FormatConversionSpec& spec,
+ absl::FormatSink* s);
+
//
// Objects
//
diff --git a/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp b/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
index bd1c6f2..bf8be6f 100644
--- a/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
@@ -101,5 +101,56 @@
ASSERT_DEVICE_ERROR(TestDispatch(max + 1, max + 1, max + 1));
}
+class ComputeValidationEntryPointTest : public ValidationTest {};
+
+// Check that entry points are optional.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameOptional) {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @compute @workgroup_size(1) fn main() {}
+ )");
+
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+ csDesc.compute.module = module;
+ csDesc.compute.entryPoint = nullptr;
+
+ device.CreateComputePipeline(&csDesc);
+
+ csDesc.layout = nullptr;
+ device.CreateComputePipeline(&csDesc);
+}
+
+// Check that entry points are required if module has multiple entry points.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameRequiredIfMultipleEntryPoints) {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @compute @workgroup_size(1) fn main1() {}
+ @compute @workgroup_size(1) fn main2() {}
+ )");
+
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+ csDesc.compute.module = module;
+ csDesc.compute.entryPoint = "main1";
+
+ device.CreateComputePipeline(&csDesc);
+
+ csDesc.compute.entryPoint = "nullptr";
+ ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+}
+
+// Check that entry points are required if module has no compatible entry points.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameRequiredIfNoCompatibleEntryPoints) {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @fragment fn main() {}
+ )");
+
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+ csDesc.compute.module = module;
+ csDesc.compute.entryPoint = nullptr;
+
+ ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+}
+
} // anonymous namespace
} // namespace dawn
diff --git a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
index 0c2d853..da9c81a 100644
--- a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -1416,6 +1416,100 @@
}
}
+// Check that entry points are optional.
+TEST_F(RenderPipelineValidationTest, EntryPointNameOptional) {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @vertex fn vertex_main() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+ }
+
+ @fragment fn fragment_main() -> @location(0) vec4f {
+ return vec4f(1.0, 0.0, 0.0, 1.0);
+ }
+ )");
+
+ utils::ComboRenderPipelineDescriptor descriptor;
+ descriptor.vertex.module = module;
+ descriptor.vertex.entryPoint = nullptr;
+ descriptor.cFragment.module = module;
+ descriptor.cFragment.entryPoint = nullptr;
+
+ // Success case.
+ device.CreateRenderPipeline(&descriptor);
+}
+
+// Check that entry points are required if module has multiple entry points.
+TEST_F(RenderPipelineValidationTest, EntryPointNameRequiredIfMultipleEntryPoints) {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @vertex fn vertex1() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+ }
+
+ @vertex fn vertex2() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+ }
+
+ @fragment fn fragment1() -> @location(0) vec4f {
+ return vec4f(1.0, 0.0, 0.0, 1.0);
+ }
+
+ @fragment fn fragment2() -> @location(0) vec4f {
+ return vec4f(1.0, 0.0, 0.0, 1.0);
+ }
+ )");
+
+ utils::ComboRenderPipelineDescriptor descriptor;
+ descriptor.vertex.module = module;
+ descriptor.cFragment.module = module;
+
+ {
+ // The vertex stage has more than one entryPoint.
+ descriptor.vertex.entryPoint = nullptr;
+ descriptor.cFragment.entryPoint = "fragment1";
+ ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ }
+
+ {
+ // The fragment stage has more than one entryPoint.
+ descriptor.vertex.entryPoint = "vertex1";
+ descriptor.cFragment.entryPoint = nullptr;
+ ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ }
+}
+
+// Check that entry points are required if module has no compatible entry points.
+TEST_F(RenderPipelineValidationTest, EntryPointNameRequiredIfNoCompatibleEntryPoints) {
+ {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @fragment fn fragment_main() -> @location(0) vec4f {
+ return vec4f(1.0, 0.0, 0.0, 1.0);
+ }
+ )");
+
+ utils::ComboRenderPipelineDescriptor descriptor;
+ descriptor.vertex.module = module;
+ descriptor.cFragment.module = module;
+
+ // The vertex stage has no entryPoint.
+ ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ }
+
+ {
+ wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ @vertex fn vertex_main() -> @builtin(position) vec4f {
+ return vec4f(0.0, 0.0, 0.0, 1.0);
+ }
+ )");
+
+ utils::ComboRenderPipelineDescriptor descriptor;
+ descriptor.vertex.module = module;
+ descriptor.cFragment.module = module;
+
+ // The fragment stage has no entryPoint.
+ ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+ }
+}
+
// Test that vertex attrib validation is for the correct entryPoint
TEST_F(RenderPipelineValidationTest, VertexAttribCorrectEntryPoint) {
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(