Return an error ComputePassEncoder when error occurs in BeginComputePass

This patch is a follow-up of the descriptorization of render pass
descriptor. In this patch we changes the return value of
BeginComputePass from nullptr to an error compute pass encoder when
there is any error in BeginComputePass() to keep it consistent with what
we do in BeginRenderPass().

This patch also provides functions to create error render/compute pass
encoders. With this patch we can create a pass encoder in error by
specifying ErrorTag in the constructor, which is more staightforward
and human readable than the current implementation.

BUG=dawn:6

Change-Id: I1899ae65804f8cecd3079dc313e7e18acb88e37c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/5140
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/CommandEncoder.cpp b/src/dawn_native/CommandEncoder.cpp
index ef186a1..89297eb 100644
--- a/src/dawn_native/CommandEncoder.cpp
+++ b/src/dawn_native/CommandEncoder.cpp
@@ -486,29 +486,28 @@
     // Implementation of the API's command recording methods
 
     ComputePassEncoderBase* CommandEncoderBase::BeginComputePass() {
+        DeviceBase* device = GetDevice();
         if (ConsumedError(ValidateCanRecordTopLevelCommands())) {
-            return nullptr;
+            return ComputePassEncoderBase::MakeError(device, this);
         }
 
         mAllocator.Allocate<BeginComputePassCmd>(Command::BeginComputePass);
 
         mEncodingState = EncodingState::ComputePass;
-        return new ComputePassEncoderBase(GetDevice(), this, &mAllocator);
+        return new ComputePassEncoderBase(device, this, &mAllocator);
     }
 
     RenderPassEncoderBase* CommandEncoderBase::BeginRenderPass(const RenderPassDescriptor* info) {
         DeviceBase* device = GetDevice();
 
         if (ConsumedError(ValidateCanRecordTopLevelCommands())) {
-            // Using nullptr as allocator will make ValidateCanRecordCommands() always return false,
-            // thus any API call on the return value will result in a Dawn validation error.
-            return new RenderPassEncoderBase(device, this, nullptr);
+            return RenderPassEncoderBase::MakeError(device, this);
         }
 
         uint32_t width = 0;
         uint32_t height = 0;
         if (ConsumedError(ValidateRenderPassDescriptorAndSetSize(device, info, &width, &height))) {
-            return new RenderPassEncoderBase(device, this, nullptr);
+            return RenderPassEncoderBase::MakeError(device, this);
         }
 
         mEncodingState = EncodingState::RenderPass;
@@ -543,7 +542,7 @@
         cmd->width = width;
         cmd->height = height;
 
-        return new RenderPassEncoderBase(GetDevice(), this, &mAllocator);
+        return new RenderPassEncoderBase(device, this, &mAllocator);
     }
 
     void CommandEncoderBase::CopyBufferToBuffer(BufferBase* source,
diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp
index b3c4ed4..406a656 100644
--- a/src/dawn_native/ComputePassEncoder.cpp
+++ b/src/dawn_native/ComputePassEncoder.cpp
@@ -27,6 +27,17 @@
         : ProgrammablePassEncoder(device, topLevelEncoder, allocator) {
     }
 
+    ComputePassEncoderBase::ComputePassEncoderBase(DeviceBase* device,
+                                                   CommandEncoderBase* topLevelEncoder,
+                                                   ErrorTag errorTag)
+        : ProgrammablePassEncoder(device, topLevelEncoder, errorTag) {
+    }
+
+    ComputePassEncoderBase* ComputePassEncoderBase::MakeError(DeviceBase* device,
+                                                              CommandEncoderBase* topLevelEncoder) {
+        return new ComputePassEncoderBase(device, topLevelEncoder, ObjectBase::kError);
+    }
+
     void ComputePassEncoderBase::Dispatch(uint32_t x, uint32_t y, uint32_t z) {
         if (mTopLevelEncoder->ConsumedError(ValidateCanRecordCommands())) {
             return;
diff --git a/src/dawn_native/ComputePassEncoder.h b/src/dawn_native/ComputePassEncoder.h
index b2acad8..a45dad5 100644
--- a/src/dawn_native/ComputePassEncoder.h
+++ b/src/dawn_native/ComputePassEncoder.h
@@ -30,8 +30,16 @@
                                CommandEncoderBase* topLevelEncoder,
                                CommandAllocator* allocator);
 
+        static ComputePassEncoderBase* MakeError(DeviceBase* device,
+                                                 CommandEncoderBase* topLevelEncoder);
+
         void Dispatch(uint32_t x, uint32_t y, uint32_t z);
         void SetPipeline(ComputePipelineBase* pipeline);
+
+      protected:
+        ComputePassEncoderBase(DeviceBase* device,
+                               CommandEncoderBase* topLevelEncoder,
+                               ErrorTag errorTag);
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/ProgrammablePassEncoder.cpp b/src/dawn_native/ProgrammablePassEncoder.cpp
index 143144b..d5eaaa4 100644
--- a/src/dawn_native/ProgrammablePassEncoder.cpp
+++ b/src/dawn_native/ProgrammablePassEncoder.cpp
@@ -27,6 +27,13 @@
                                                      CommandEncoderBase* topLevelEncoder,
                                                      CommandAllocator* allocator)
         : ObjectBase(device), mTopLevelEncoder(topLevelEncoder), mAllocator(allocator) {
+        DAWN_ASSERT(allocator != nullptr);
+    }
+
+    ProgrammablePassEncoder::ProgrammablePassEncoder(DeviceBase* device,
+                                                     CommandEncoderBase* topLevelEncoder,
+                                                     ErrorTag errorTag)
+        : ObjectBase(device, errorTag), mTopLevelEncoder(topLevelEncoder), mAllocator(nullptr) {
     }
 
     void ProgrammablePassEncoder::EndPass() {
diff --git a/src/dawn_native/ProgrammablePassEncoder.h b/src/dawn_native/ProgrammablePassEncoder.h
index 4924c9c..ae7eb46 100644
--- a/src/dawn_native/ProgrammablePassEncoder.h
+++ b/src/dawn_native/ProgrammablePassEncoder.h
@@ -46,6 +46,11 @@
                               const void* data);
 
       protected:
+        // Construct an "error" programmable pass encoder.
+        ProgrammablePassEncoder(DeviceBase* device,
+                                CommandEncoderBase* topLevelEncoder,
+                                ErrorTag errorTag);
+
         MaybeError ValidateCanRecordCommands() const;
 
         // The allocator is borrowed from the top level encoder. Keep a reference to the encoder
diff --git a/src/dawn_native/RenderPassEncoder.cpp b/src/dawn_native/RenderPassEncoder.cpp
index 214b1cd..6bf0b8e 100644
--- a/src/dawn_native/RenderPassEncoder.cpp
+++ b/src/dawn_native/RenderPassEncoder.cpp
@@ -30,6 +30,17 @@
         : ProgrammablePassEncoder(device, topLevelEncoder, allocator) {
     }
 
+    RenderPassEncoderBase::RenderPassEncoderBase(DeviceBase* device,
+                                                 CommandEncoderBase* topLevelEncoder,
+                                                 ErrorTag errorTag)
+        : ProgrammablePassEncoder(device, topLevelEncoder, errorTag) {
+    }
+
+    RenderPassEncoderBase* RenderPassEncoderBase::MakeError(DeviceBase* device,
+                                                            CommandEncoderBase* topLevelEncoder) {
+        return new RenderPassEncoderBase(device, topLevelEncoder, ObjectBase::kError);
+    }
+
     void RenderPassEncoderBase::Draw(uint32_t vertexCount,
                                      uint32_t instanceCount,
                                      uint32_t firstVertex,
diff --git a/src/dawn_native/RenderPassEncoder.h b/src/dawn_native/RenderPassEncoder.h
index 08547a6..408cbe7 100644
--- a/src/dawn_native/RenderPassEncoder.h
+++ b/src/dawn_native/RenderPassEncoder.h
@@ -30,6 +30,9 @@
                               CommandEncoderBase* topLevelEncoder,
                               CommandAllocator* allocator);
 
+        static RenderPassEncoderBase* MakeError(DeviceBase* device,
+                                                CommandEncoderBase* topLevelEncoder);
+
         void Draw(uint32_t vertexCount,
                   uint32_t instanceCount,
                   uint32_t firstVertex,
@@ -60,6 +63,11 @@
                               BufferBase* const* buffers,
                               uint32_t const* offsets);
         void SetIndexBuffer(BufferBase* buffer, uint32_t offset);
+
+      protected:
+        RenderPassEncoderBase(DeviceBase* device,
+                              CommandEncoderBase* topLevelEncoder,
+                              ErrorTag errorTag);
     };
 
 }  // namespace dawn_native
diff --git a/src/tests/unittests/validation/CommandBufferValidationTests.cpp b/src/tests/unittests/validation/CommandBufferValidationTests.cpp
index 3135ba1..8ceb583 100644
--- a/src/tests/unittests/validation/CommandBufferValidationTests.cpp
+++ b/src/tests/unittests/validation/CommandBufferValidationTests.cpp
@@ -130,6 +130,31 @@
     }
 }
 
+// Test that beginning a compute pass before ending the previous pass causes an error.
+TEST_F(CommandBufferValidationTest, BeginComputePassBeforeEndPreviousPass) {
+    DummyRenderPass dummyRenderPass(device);
+
+    // Beginning a compute pass before ending a render pass causes an error.
+    {
+        dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+        dawn::RenderPassEncoder renderPass = encoder.BeginRenderPass(&dummyRenderPass);
+        dawn::ComputePassEncoder computePass = encoder.BeginComputePass();
+        computePass.EndPass();
+        renderPass.EndPass();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+
+    // Beginning a compute pass before ending a compute pass causes an error.
+    {
+        dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+        dawn::ComputePassEncoder computePass1 = encoder.BeginComputePass();
+        dawn::ComputePassEncoder computePass2 = encoder.BeginComputePass();
+        computePass2.EndPass();
+        computePass1.EndPass();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+}
+
 // Test that using a single buffer in multiple read usages in the same pass is allowed.
 TEST_F(CommandBufferValidationTest, BufferWithMultipleReadUsage) {
     // Create a buffer used as both vertex and index buffer.