Implement entryPoint defaulting

Spec: https://github.com/gpuweb/gpuweb/pull/4387

Bug: dawn:2254
Change-Id: I51c67253b095f5daf59dac378624b6bc38fe524d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/161901
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Fr <beaufort.francois@gmail.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/dawn.json b/dawn.json
index 7a80ae8..fc76c12 100644
--- a/dawn.json
+++ b/dawn.json
@@ -2228,7 +2228,7 @@
         "extensible": "in",
         "members": [
             {"name": "module", "type": "shader module"},
-            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
             {"name": "constant count", "type": "size_t", "default": 0},
             {"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"}
         ]
@@ -2832,7 +2832,7 @@
         "extensible": "in",
         "members": [
             {"name": "module", "type": "shader module"},
-            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
             {"name": "constant count", "type": "size_t", "default": 0},
             {"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"},
             {"name": "buffer count", "type": "size_t", "default": 0},
@@ -2912,7 +2912,7 @@
         "extensible": "in",
         "members": [
             {"name": "module", "type": "shader module"},
-            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"},
+            {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen", "optional": true},
             {"name": "constant count", "type": "size_t", "default": 0},
             {"name": "constants", "type": "constant entry", "annotation": "const*", "length": "constant count"},
             {"name": "target count", "type": "size_t"},
diff --git a/src/dawn/native/ComputePipeline.cpp b/src/dawn/native/ComputePipeline.cpp
index cbb5cb6..b3b99d9 100644
--- a/src/dawn/native/ComputePipeline.cpp
+++ b/src/dawn/native/ComputePipeline.cpp
@@ -43,12 +43,14 @@
         DAWN_TRY(device->ValidateObject(descriptor->layout));
     }
 
-    DAWN_TRY_CONTEXT(ValidateProgrammableStage(
-                         device, descriptor->compute.module, descriptor->compute.entryPoint,
-                         descriptor->compute.constantCount, descriptor->compute.constants,
-                         descriptor->layout, SingleShaderStage::Compute),
-                     "validating compute stage (%s, entryPoint: %s).", descriptor->compute.module,
-                     descriptor->compute.entryPoint);
+    ShaderModuleEntryPoint entryPoint;
+    DAWN_TRY_ASSIGN_CONTEXT(entryPoint,
+                            ValidateProgrammableStage(
+                                device, descriptor->compute.module, descriptor->compute.entryPoint,
+                                descriptor->compute.constantCount, descriptor->compute.constants,
+                                descriptor->layout, SingleShaderStage::Compute),
+                            "validating compute stage (%s, entryPoint: %s).",
+                            descriptor->compute.module, descriptor->compute.entryPoint);
     return {};
 }
 
diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp
index 347c738..ff1f5b9 100644
--- a/src/dawn/native/Pipeline.cpp
+++ b/src/dawn/native/Pipeline.cpp
@@ -47,33 +47,46 @@
 }  // namespace
 
 namespace dawn::native {
-MaybeError ValidateProgrammableStage(DeviceBase* device,
-                                     const ShaderModuleBase* module,
-                                     const std::string& entryPoint,
-                                     uint32_t constantCount,
-                                     const ConstantEntry* constants,
-                                     const PipelineLayoutBase* layout,
-                                     SingleShaderStage stage) {
+ResultOrError<ShaderModuleEntryPoint> ValidateProgrammableStage(DeviceBase* device,
+                                                                const ShaderModuleBase* module,
+                                                                const char* entryPointName,
+                                                                uint32_t constantCount,
+                                                                const ConstantEntry* constants,
+                                                                const PipelineLayoutBase* layout,
+                                                                SingleShaderStage stage) {
     DAWN_TRY(device->ValidateObject(module));
 
-    DAWN_INVALID_IF(!module->HasEntryPoint(entryPoint),
-                    "Entry point \"%s\" doesn't exist in the shader module %s.", entryPoint,
-                    module);
+    if (entryPointName) {
+        DAWN_INVALID_IF(!module->HasEntryPoint(entryPointName),
+                        "Entry point \"%s\" doesn't exist in the shader module %s.", entryPointName,
+                        module);
+    } else {
+        size_t entryPointCount = module->GetEntryPointCount(stage);
+        if (entryPointCount == 0) {
+            return DAWN_VALIDATION_ERROR(
+                "Compatible entry point for stage (%s) doesn't exist in the shader module %s.",
+                stage, module);
+        } else if (entryPointCount > 1) {
+            return DAWN_VALIDATION_ERROR(
+                "Multiple entry points for stage (%s) exist in the shader module %s.", stage,
+                module);
+        }
+    }
 
-    const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint);
+    ShaderModuleEntryPoint entryPoint = module->ReifyEntryPointName(entryPointName, stage);
+    const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint.name);
 
     if (!metadata.infringedLimitErrors.empty()) {
         std::ostringstream limitList;
         for (const std::string& limit : metadata.infringedLimitErrors) {
             limitList << " - " << limit << "\n";
         }
-        return DAWN_VALIDATION_ERROR("Entry point \"%s\" infringes limits:\n%s", entryPoint,
-                                     limitList.str());
+        return DAWN_VALIDATION_ERROR("%s infringes limits:\n%s", &entryPoint, limitList.str());
     }
 
     DAWN_INVALID_IF(metadata.stage != stage,
                     "The stage (%s) of the entry point \"%s\" isn't the expected one (%s).",
-                    metadata.stage, entryPoint, stage);
+                    metadata.stage, entryPoint.name, stage);
 
     if (layout != nullptr) {
         DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
@@ -163,7 +176,7 @@
             uninitializedConstantsArray);
     }
 
-    return {};
+    return entryPoint;
 }
 
 WGPUCreatePipelineAsyncStatus CreatePipelineAsyncStatusFromErrorType(InternalErrorType error) {
diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h
index 7bac9f7..c2234d2 100644
--- a/src/dawn/native/Pipeline.h
+++ b/src/dawn/native/Pipeline.h
@@ -45,13 +45,13 @@
 
 namespace dawn::native {
 
-MaybeError ValidateProgrammableStage(DeviceBase* device,
-                                     const ShaderModuleBase* module,
-                                     const std::string& entryPoint,
-                                     uint32_t constantCount,
-                                     const ConstantEntry* constants,
-                                     const PipelineLayoutBase* layout,
-                                     SingleShaderStage stage);
+ResultOrError<ShaderModuleEntryPoint> ValidateProgrammableStage(DeviceBase* device,
+                                                                const ShaderModuleBase* module,
+                                                                const char* entryPointName,
+                                                                uint32_t constantCount,
+                                                                const ConstantEntry* constants,
+                                                                const PipelineLayoutBase* layout,
+                                                                SingleShaderStage stage);
 
 WGPUCreatePipelineAsyncStatus CreatePipelineAsyncStatusFromErrorType(InternalErrorType error);
 
diff --git a/src/dawn/native/PipelineLayout.cpp b/src/dawn/native/PipelineLayout.cpp
index 4a4180c..165a070 100644
--- a/src/dawn/native/PipelineLayout.cpp
+++ b/src/dawn/native/PipelineLayout.cpp
@@ -92,6 +92,17 @@
     return {};
 }
 
+StageAndDescriptor::StageAndDescriptor(SingleShaderStage shaderStage,
+                                       ShaderModuleBase* module,
+                                       const char* entryPoint,
+                                       size_t constantCount,
+                                       ConstantEntry const* constants)
+    : shaderStage(shaderStage),
+      module(module),
+      entryPoint(module->ReifyEntryPointName(entryPoint, shaderStage).name),
+      constantCount(constantCount),
+      constants(constants) {}
+
 // PipelineLayoutBase
 
 PipelineLayoutBase::PipelineLayoutBase(DeviceBase* device,
diff --git a/src/dawn/native/PipelineLayout.h b/src/dawn/native/PipelineLayout.h
index 756f8bb..8f7dcbd 100644
--- a/src/dawn/native/PipelineLayout.h
+++ b/src/dawn/native/PipelineLayout.h
@@ -56,6 +56,12 @@
 using BindGroupLayoutMask = ityp::bitset<BindGroupIndex, kMaxBindGroups>;
 
 struct StageAndDescriptor {
+    StageAndDescriptor(SingleShaderStage shaderStage,
+                       ShaderModuleBase* module,
+                       const char* entryPoint,
+                       size_t constantCount,
+                       ConstantEntry const* constants);
+
     SingleShaderStage shaderStage;
     ShaderModuleBase* module;
     std::string entryPoint;
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index 8286dd6..fdf2f1c 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -173,10 +173,11 @@
     return {};
 }
 
-MaybeError ValidateVertexState(DeviceBase* device,
-                               const VertexState* descriptor,
-                               const PipelineLayoutBase* layout,
-                               wgpu::PrimitiveTopology primitiveTopology) {
+ResultOrError<ShaderModuleEntryPoint> ValidateVertexState(
+    DeviceBase* device,
+    const VertexState* descriptor,
+    const PipelineLayoutBase* layout,
+    wgpu::PrimitiveTopology primitiveTopology) {
     DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");
 
     const CombinedLimits& limits = device->GetLimits();
@@ -185,13 +186,15 @@
                     "Vertex buffer count (%u) exceeds the maximum number of vertex buffers (%u).",
                     descriptor->bufferCount, limits.v1.maxVertexBuffers);
 
-    DAWN_TRY_CONTEXT(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
-                                               descriptor->constantCount, descriptor->constants,
-                                               layout, SingleShaderStage::Vertex),
-                     "validating vertex stage (%s, entryPoint: %s).", descriptor->module,
-                     descriptor->entryPoint);
-    const EntryPointMetadata& vertexMetadata =
-        descriptor->module->GetEntryPoint(descriptor->entryPoint);
+    ShaderModuleEntryPoint entryPoint;
+    DAWN_TRY_ASSIGN_CONTEXT(
+        entryPoint,
+        ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+                                  descriptor->constantCount, descriptor->constants, layout,
+                                  SingleShaderStage::Vertex),
+        "validating vertex stage (%s, entryPoint: %s).", descriptor->module,
+        descriptor->entryPoint);
+    const EntryPointMetadata& vertexMetadata = descriptor->module->GetEntryPoint(entryPoint.name);
     if (primitiveTopology == wgpu::PrimitiveTopology::PointList) {
         DAWN_INVALID_IF(
             vertexMetadata.totalInterStageShaderComponents + 1 >
@@ -239,12 +242,12 @@
         VertexAttributeLocation firstMissing = ityp::Sub(
             GetHighestBitIndexPlusOne(missingAttributes), VertexAttributeLocation(uint8_t(1)));
         return DAWN_VALIDATION_ERROR(
-            "Vertex attribute slot %u used in (%s, entryPoint: %s) is not present in the "
+            "Vertex attribute slot %u used in (%s, %s) is not present in the "
             "VertexState.",
-            uint8_t(firstMissing), descriptor->module, descriptor->entryPoint);
+            uint8_t(firstMissing), descriptor->module, &entryPoint);
     }
 
-    return {};
+    return entryPoint;
 }
 
 MaybeError ValidatePrimitiveState(const DeviceBase* device, const PrimitiveState* descriptor) {
@@ -543,35 +546,36 @@
     return {};
 }
 
-MaybeError ValidateFragmentState(DeviceBase* device,
-                                 const FragmentState* descriptor,
-                                 const PipelineLayoutBase* layout,
-                                 const DepthStencilState* depthStencil,
-                                 const MultisampleState& multisample) {
+ResultOrError<ShaderModuleEntryPoint> ValidateFragmentState(DeviceBase* device,
+                                                            const FragmentState* descriptor,
+                                                            const PipelineLayoutBase* layout,
+                                                            const DepthStencilState* depthStencil,
+                                                            const MultisampleState& multisample) {
     DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");
 
-    DAWN_TRY_CONTEXT(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
-                                               descriptor->constantCount, descriptor->constants,
-                                               layout, SingleShaderStage::Fragment),
-                     "validating fragment stage (%s, entryPoint: %s).", descriptor->module,
-                     descriptor->entryPoint);
+    ShaderModuleEntryPoint entryPoint;
+    DAWN_TRY_ASSIGN_CONTEXT(
+        entryPoint,
+        ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
+                                  descriptor->constantCount, descriptor->constants, layout,
+                                  SingleShaderStage::Fragment),
+        "validating fragment stage (%s, entryPoint: %s).", descriptor->module,
+        descriptor->entryPoint);
 
-    const EntryPointMetadata& fragmentMetadata =
-        descriptor->module->GetEntryPoint(descriptor->entryPoint);
+    const EntryPointMetadata& fragmentMetadata = descriptor->module->GetEntryPoint(entryPoint.name);
 
     if (fragmentMetadata.usesFragDepth) {
-        DAWN_INVALID_IF(
-            depthStencil == nullptr,
-            "Depth stencil state is not present when fragment stage (%s, entryPoint: %s) is "
-            "writing to frag_depth.",
-            descriptor->module, descriptor->entryPoint);
+        DAWN_INVALID_IF(depthStencil == nullptr,
+                        "Depth stencil state is not present when fragment stage (%s, %s) is "
+                        "writing to frag_depth.",
+                        descriptor->module, &entryPoint);
         const Format* depthStencilFormat;
         DAWN_TRY_ASSIGN(depthStencilFormat, device->GetInternalFormat(depthStencil->format));
         DAWN_INVALID_IF(!depthStencilFormat->HasDepth(),
                         "Depth stencil state format (%s) has no depth aspect when fragment stage "
-                        "(%s, entryPoint: %s) is "
+                        "(%s, %s) is "
                         "writing to frag_depth.",
-                        depthStencil->format, descriptor->module, descriptor->entryPoint);
+                        depthStencil->format, descriptor->module, &entryPoint);
     }
 
     uint32_t maxColorAttachments = device->GetLimits().v1.maxColorAttachments;
@@ -639,9 +643,8 @@
     if (device->IsCompatibilityMode()) {
         DAWN_INVALID_IF(
             fragmentMetadata.usesSampleMaskOutput,
-            "sample_mask is not supported in compatibility mode in the fragment stage (%s, "
-            "entryPoint: %s)",
-            descriptor->module, descriptor->entryPoint);
+            "sample_mask is not supported in compatibility mode in the fragment stage (%s, %s)",
+            descriptor->module, &entryPoint);
 
         // Check that all the color target states match.
         ColorAttachmentIndex firstColorTargetIndex{};
@@ -659,16 +662,18 @@
         }
     }
 
-    return {};
+    return entryPoint;
 }
 
 MaybeError ValidateInterStageMatching(DeviceBase* device,
                                       const VertexState& vertexState,
-                                      const FragmentState& fragmentState) {
+                                      const ShaderModuleEntryPoint& vertexEntryPoint,
+                                      const FragmentState& fragmentState,
+                                      const ShaderModuleEntryPoint& fragmentEntryPoint) {
     const EntryPointMetadata& vertexMetadata =
-        vertexState.module->GetEntryPoint(vertexState.entryPoint);
+        vertexState.module->GetEntryPoint(vertexEntryPoint.name);
     const EntryPointMetadata& fragmentMetadata =
-        fragmentState.module->GetEntryPoint(fragmentState.entryPoint);
+        fragmentState.module->GetEntryPoint(fragmentEntryPoint.name);
 
     size_t maxInterStageShaderVariables = device->GetLimits().v1.maxInterStageShaderVariables;
     DAWN_ASSERT(vertexMetadata.usedInterStageVariables.size() == maxInterStageShaderVariables);
@@ -744,9 +749,11 @@
         DAWN_TRY(device->ValidateObject(descriptor->layout));
     }
 
-    DAWN_TRY_CONTEXT(ValidateVertexState(device, &descriptor->vertex, descriptor->layout,
-                                         descriptor->primitive.topology),
-                     "validating vertex state.");
+    ShaderModuleEntryPoint vertexEntryPoint;
+    DAWN_TRY_ASSIGN_CONTEXT(vertexEntryPoint,
+                            ValidateVertexState(device, &descriptor->vertex, descriptor->layout,
+                                                descriptor->primitive.topology),
+                            "validating vertex state.");
 
     DAWN_TRY_CONTEXT(ValidatePrimitiveState(device, &descriptor->primitive),
                      "validating primitive state.");
@@ -764,9 +771,12 @@
         "alphaToCoverageEnabled is true when fragment state is not present.");
 
     if (descriptor->fragment != nullptr) {
-        DAWN_TRY_CONTEXT(ValidateFragmentState(device, descriptor->fragment, descriptor->layout,
-                                               descriptor->depthStencil, descriptor->multisample),
-                         "validating fragment state.");
+        ShaderModuleEntryPoint fragmentEntryPoint;
+        DAWN_TRY_ASSIGN_CONTEXT(
+            fragmentEntryPoint,
+            ValidateFragmentState(device, descriptor->fragment, descriptor->layout,
+                                  descriptor->depthStencil, descriptor->multisample),
+            "validating fragment state.");
 
         bool hasStorageAttachments =
             descriptor->layout != nullptr && descriptor->layout->HasAnyStorageAttachments();
@@ -774,7 +784,8 @@
                             !hasStorageAttachments,
                         "No attachment was specified (color, depth-stencil or other).");
 
-        DAWN_TRY(ValidateInterStageMatching(device, descriptor->vertex, *(descriptor->fragment)));
+        DAWN_TRY(ValidateInterStageMatching(device, descriptor->vertex, vertexEntryPoint,
+                                            *(descriptor->fragment), fragmentEntryPoint));
     }
 
     return {};
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index b119d8e..3e55c47 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -1263,6 +1263,19 @@
     return mEntryPoints.count(entryPoint) > 0;
 }
 
+ShaderModuleEntryPoint ShaderModuleBase::ReifyEntryPointName(const char* entryPointName,
+                                                             SingleShaderStage stage) const {
+    ShaderModuleEntryPoint entryPoint;
+    if (entryPointName) {
+        entryPoint.defaulted = false;
+        entryPoint.name = entryPointName;
+    } else {
+        entryPoint.defaulted = true;
+        entryPoint.name = mDefaultEntryPointNames[stage];
+    }
+    return entryPoint;
+}
+
 const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const {
     DAWN_ASSERT(HasEntryPoint(entryPoint));
     return *mEntryPoints.at(entryPoint);
@@ -1339,6 +1352,18 @@
 
     DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
                                     &mEntryPoints, &mEnabledWGSLExtensions));
+
+    for (auto stage : IterateStages(kAllStages)) {
+        mEntryPointCounts[stage] = 0;
+    }
+    for (auto& [name, metadata] : mEntryPoints) {
+        SingleShaderStage stage = metadata->stage;
+        if (mEntryPointCounts[stage] == 0) {
+            mDefaultEntryPointNames[stage] = name;
+        }
+        mEntryPointCounts[stage]++;
+    }
+
     return {};
 }
 
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index 1df1170..86fc3ca 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -120,6 +120,11 @@
     std::unique_ptr<TintSource> tintSource;
 };
 
+struct ShaderModuleEntryPoint {
+    bool defaulted;
+    std::string name;
+};
+
 MaybeError ValidateAndParseShaderModule(DeviceBase* device,
                                         const ShaderModuleDescriptor* descriptor,
                                         ShaderModuleParseResult* parseResult,
@@ -298,6 +303,13 @@
     // Return true iff the program has an entrypoint called `entryPoint`.
     bool HasEntryPoint(const std::string& entryPoint) const;
 
+    // Return the number of entry points for a stage.
+    size_t GetEntryPointCount(SingleShaderStage stage) const { return mEntryPointCounts[stage]; }
+
+    // Return the entry point for a stage. If no entry point name, returns the default one.
+    ShaderModuleEntryPoint ReifyEntryPointName(const char* entryPointName,
+                                               SingleShaderStage stage) const;
+
     // Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
     // must be true.
     const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;
@@ -334,6 +346,8 @@
     std::string mWgsl;
 
     EntryPointMetadataTable mEntryPoints;
+    PerStage<std::string> mDefaultEntryPointNames;
+    PerStage<size_t> mEntryPointCounts;
     WGSLExtensionSet mEnabledWGSLExtensions;
     std::unique_ptr<tint::Program> mTintProgram;
     std::unique_ptr<TintSource> mTintSource;  // Keep the tint::Source::File alive
diff --git a/src/dawn/native/webgpu_absl_format.cpp b/src/dawn/native/webgpu_absl_format.cpp
index a2dd19c..1abc5c7 100644
--- a/src/dawn/native/webgpu_absl_format.cpp
+++ b/src/dawn/native/webgpu_absl_format.cpp
@@ -165,6 +165,22 @@
     return {true};
 }
 
+absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert(
+    const ShaderModuleEntryPoint* value,
+    const absl::FormatConversionSpec& spec,
+    absl::FormatSink* s) {
+    if (value == nullptr) {
+        s->Append("[null]");
+        return {true};
+    }
+    s->Append(absl::StrFormat("[EntryPoint \"%s\"", value->name));
+    if (value->defaulted) {
+        s->Append(" (defaulted)");
+    }
+    s->Append("]");
+    return {true};
+}
+
 //
 // Objects
 //
diff --git a/src/dawn/native/webgpu_absl_format.h b/src/dawn/native/webgpu_absl_format.h
index d1dcef6..e8f9da8 100644
--- a/src/dawn/native/webgpu_absl_format.h
+++ b/src/dawn/native/webgpu_absl_format.h
@@ -94,6 +94,12 @@
     const absl::FormatConversionSpec& spec,
     absl::FormatSink* s);
 
+struct ShaderModuleEntryPoint;
+absl::FormatConvertResult<absl::FormatConversionCharSet::kString> AbslFormatConvert(
+    const ShaderModuleEntryPoint* value,
+    const absl::FormatConversionSpec& spec,
+    absl::FormatSink* s);
+
 //
 // Objects
 //
diff --git a/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp b/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
index bd1c6f2..bf8be6f 100644
--- a/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/ComputeValidationTests.cpp
@@ -101,5 +101,56 @@
     ASSERT_DEVICE_ERROR(TestDispatch(max + 1, max + 1, max + 1));
 }
 
+class ComputeValidationEntryPointTest : public ValidationTest {};
+
+// Check that entry points are optional.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameOptional) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @compute @workgroup_size(1) fn main() {}
+    )");
+
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+    csDesc.compute.module = module;
+    csDesc.compute.entryPoint = nullptr;
+
+    device.CreateComputePipeline(&csDesc);
+
+    csDesc.layout = nullptr;
+    device.CreateComputePipeline(&csDesc);
+}
+
+// Check that entry points are required if module has multiple entry points.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameRequiredIfMultipleEntryPoints) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @compute @workgroup_size(1) fn main1() {}
+        @compute @workgroup_size(1) fn main2() {}
+    )");
+
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+    csDesc.compute.module = module;
+    csDesc.compute.entryPoint = "main1";
+
+    device.CreateComputePipeline(&csDesc);
+
+    csDesc.compute.entryPoint = "nullptr";
+    ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+}
+
+// Check that entry points are required if module has no compatible entry points.
+TEST_F(ComputeValidationEntryPointTest, EntryPointNameRequiredIfNoCompatibleEntryPoints) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @fragment fn main() {}
+    )");
+
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.layout = utils::MakeBasicPipelineLayout(device, nullptr);
+    csDesc.compute.module = module;
+    csDesc.compute.entryPoint = nullptr;
+
+    ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+}
+
 }  // anonymous namespace
 }  // namespace dawn
diff --git a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
index 0c2d853..da9c81a 100644
--- a/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -1416,6 +1416,100 @@
     }
 }
 
+// Check that entry points are optional.
+TEST_F(RenderPipelineValidationTest, EntryPointNameOptional) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @vertex fn vertex_main() -> @builtin(position) vec4f {
+            return vec4f(0.0, 0.0, 0.0, 1.0);
+        }
+
+        @fragment fn fragment_main() -> @location(0) vec4f {
+            return vec4f(1.0, 0.0, 0.0, 1.0);
+        }
+    )");
+
+    utils::ComboRenderPipelineDescriptor descriptor;
+    descriptor.vertex.module = module;
+    descriptor.vertex.entryPoint = nullptr;
+    descriptor.cFragment.module = module;
+    descriptor.cFragment.entryPoint = nullptr;
+
+    // Success case.
+    device.CreateRenderPipeline(&descriptor);
+}
+
+// Check that entry points are required if module has multiple entry points.
+TEST_F(RenderPipelineValidationTest, EntryPointNameRequiredIfMultipleEntryPoints) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        @vertex fn vertex1() -> @builtin(position) vec4f {
+            return vec4f(0.0, 0.0, 0.0, 1.0);
+        }
+
+        @vertex fn vertex2() -> @builtin(position) vec4f {
+            return vec4f(0.0, 0.0, 0.0, 1.0);
+        }
+
+        @fragment fn fragment1() -> @location(0) vec4f {
+            return vec4f(1.0, 0.0, 0.0, 1.0);
+        }
+
+        @fragment fn fragment2() -> @location(0) vec4f {
+            return vec4f(1.0, 0.0, 0.0, 1.0);
+        }
+    )");
+
+    utils::ComboRenderPipelineDescriptor descriptor;
+    descriptor.vertex.module = module;
+    descriptor.cFragment.module = module;
+
+    {
+        // The vertex stage has more than one entryPoint.
+        descriptor.vertex.entryPoint = nullptr;
+        descriptor.cFragment.entryPoint = "fragment1";
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+    }
+
+    {
+        // The fragment stage has more than one entryPoint.
+        descriptor.vertex.entryPoint = "vertex1";
+        descriptor.cFragment.entryPoint = nullptr;
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+    }
+}
+
+// Check that entry points are required if module has no compatible entry points.
+TEST_F(RenderPipelineValidationTest, EntryPointNameRequiredIfNoCompatibleEntryPoints) {
+    {
+        wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+            @fragment fn fragment_main() -> @location(0) vec4f {
+                return vec4f(1.0, 0.0, 0.0, 1.0);
+            }
+        )");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = module;
+        descriptor.cFragment.module = module;
+
+        // The vertex stage has no entryPoint.
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+    }
+
+    {
+        wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+            @vertex fn vertex_main() -> @builtin(position) vec4f {
+                return vec4f(0.0, 0.0, 0.0, 1.0);
+            }
+        )");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = module;
+        descriptor.cFragment.module = module;
+
+        // The fragment stage has no entryPoint.
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+    }
+}
+
 // Test that vertex attrib validation is for the correct entryPoint
 TEST_F(RenderPipelineValidationTest, VertexAttribCorrectEntryPoint) {
     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(