Make GetBindGroupLayout error for indices past the last defined BGL.

Fixed: dawn:1565
Change-Id: I8a482623fcbd68648c451499ce769b871cf89c0a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104820
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp
index 78604f3..443ae3d 100644
--- a/src/dawn/native/Pipeline.cpp
+++ b/src/dawn/native/Pipeline.cpp
@@ -191,25 +191,25 @@
     return mStageMask;
 }
 
-MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) {
+MaybeError PipelineBase::ValidateGetBindGroupLayout(BindGroupIndex groupIndex) {
     DAWN_TRY(GetDevice()->ValidateIsAlive());
     DAWN_TRY(GetDevice()->ValidateObject(this));
     DAWN_TRY(GetDevice()->ValidateObject(mLayout.Get()));
-    DAWN_INVALID_IF(groupIndex >= kMaxBindGroups,
+    DAWN_INVALID_IF(groupIndex >= kMaxBindGroupsTyped,
                     "Bind group layout index (%u) exceeds the maximum number of bind groups (%u).",
-                    groupIndex, kMaxBindGroups);
+                    static_cast<uint32_t>(groupIndex), kMaxBindGroups);
+    DAWN_INVALID_IF(
+        !mLayout->GetBindGroupLayoutsMask()[groupIndex],
+        "Bind group layout index (%u) doesn't correspond to a bind group for this pipeline.",
+        static_cast<uint32_t>(groupIndex));
     return {};
 }
 
 ResultOrError<Ref<BindGroupLayoutBase>> PipelineBase::GetBindGroupLayout(uint32_t groupIndexIn) {
-    DAWN_TRY(ValidateGetBindGroupLayout(groupIndexIn));
-
     BindGroupIndex groupIndex(groupIndexIn);
-    if (!mLayout->GetBindGroupLayoutsMask()[groupIndex]) {
-        return Ref<BindGroupLayoutBase>(GetDevice()->GetEmptyBindGroupLayout());
-    } else {
-        return Ref<BindGroupLayoutBase>(mLayout->GetBindGroupLayout(groupIndex));
-    }
+
+    DAWN_TRY(ValidateGetBindGroupLayout(groupIndex));
+    return Ref<BindGroupLayoutBase>(mLayout->GetBindGroupLayout(groupIndex));
 }
 
 BindGroupLayoutBase* PipelineBase::APIGetBindGroupLayout(uint32_t groupIndexIn) {
diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h
index adf8f8d..fa0b65b 100644
--- a/src/dawn/native/Pipeline.h
+++ b/src/dawn/native/Pipeline.h
@@ -84,7 +84,7 @@
     explicit PipelineBase(DeviceBase* device);
 
   private:
-    MaybeError ValidateGetBindGroupLayout(uint32_t group);
+    MaybeError ValidateGetBindGroupLayout(BindGroupIndex group);
 
     wgpu::ShaderStage mStageMask = wgpu::ShaderStage::None;
     PerStage<ProgrammableStage> mStages;
diff --git a/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp b/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp
index 8434915..e687209 100644
--- a/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/BindGroupValidationTests.cpp
@@ -2062,6 +2062,51 @@
     TestComputePassBindGroup(bindGroup, nullptr, 0, false);
 }
 
+// Test that a pipeline with empty bindgroups layouts requires empty bindgroups to be set.
+TEST_F(SetBindGroupValidationTest, EmptyBindGroupsAreRequired) {
+    wgpu::BindGroupLayout emptyBGL = utils::MakeBindGroupLayout(device, {});
+    wgpu::PipelineLayout pl =
+        utils::MakePipelineLayout(device, {emptyBGL, emptyBGL, emptyBGL, emptyBGL});
+
+    wgpu::ComputePipelineDescriptor pipelineDesc;
+    pipelineDesc.layout = pl;
+    pipelineDesc.compute.entryPoint = "main";
+    pipelineDesc.compute.module = utils::CreateShaderModule(device, R"(
+        @compute @workgroup_size(1) fn main() {
+        }
+    )");
+    wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&pipelineDesc);
+
+    wgpu::BindGroup emptyBindGroup = utils::MakeBindGroup(device, emptyBGL, {});
+
+    // Control case, setting 4 empty bindgroups works.
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.SetBindGroup(0, emptyBindGroup);
+        pass.SetBindGroup(1, emptyBindGroup);
+        pass.SetBindGroup(2, emptyBindGroup);
+        pass.SetBindGroup(3, emptyBindGroup);
+        pass.DispatchWorkgroups(1);
+        pass.End();
+        encoder.Finish();
+    }
+
+    // Error case, setting only the first three empty bindgroups.
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.SetBindGroup(0, emptyBindGroup);
+        pass.SetBindGroup(1, emptyBindGroup);
+        pass.SetBindGroup(2, emptyBindGroup);
+        pass.DispatchWorkgroups(1);
+        pass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+}
+
 class SetBindGroupPersistenceValidationTest : public ValidationTest {
   protected:
     void SetUp() override {
diff --git a/src/dawn/tests/unittests/validation/GetBindGroupLayoutValidationTests.cpp b/src/dawn/tests/unittests/validation/GetBindGroupLayoutValidationTests.cpp
index 1bf2af4..dd5f478 100644
--- a/src/dawn/tests/unittests/validation/GetBindGroupLayoutValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/GetBindGroupLayoutValidationTests.cpp
@@ -980,7 +980,8 @@
                             .GetBindGroupLayout(kMaxBindGroups + 1));
 }
 
-// Test that unused indices return the empty bind group layout.
+// Test that unused indices return the empty bind group layout if before the last used index, an
+// error otherwise.
 TEST_F(GetBindGroupLayoutTests, UnusedIndex) {
     // This test works assuming Dawn Native's object deduplication.
     // Getting the same pointer to equivalent bind group layouts is an implementation detail of Dawn
@@ -1011,8 +1012,7 @@
         pipeline.GetBindGroupLayout(1).Get(), emptyBindGroupLayout.Get()));  // Not Used.
     EXPECT_FALSE(dawn::native::BindGroupLayoutBindingsEqualForTesting(
         pipeline.GetBindGroupLayout(2).Get(), emptyBindGroupLayout.Get()));  // Used.
-    EXPECT_TRUE(dawn::native::BindGroupLayoutBindingsEqualForTesting(
-        pipeline.GetBindGroupLayout(3).Get(), emptyBindGroupLayout.Get()));  // Not used
+    ASSERT_DEVICE_ERROR(pipeline.GetBindGroupLayout(3));  // Past last defined BGL, error!
 }
 
 // Test that after explicitly creating a pipeline with a pipeline layout, calling
@@ -1064,19 +1064,6 @@
     wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&pipelineDesc);
 
     EXPECT_EQ(pipeline.GetBindGroupLayout(0).Get(), bindGroupLayout.Get());
-
-    {
-        wgpu::BindGroupLayoutDescriptor emptyDesc = {};
-        emptyDesc.entryCount = 0;
-        emptyDesc.entries = nullptr;
-
-        wgpu::BindGroupLayout emptyBindGroupLayout = device.CreateBindGroupLayout(&emptyDesc);
-
-        // Check that the rest of the bind group layouts reflect the empty one.
-        EXPECT_EQ(pipeline.GetBindGroupLayout(1).Get(), emptyBindGroupLayout.Get());
-        EXPECT_EQ(pipeline.GetBindGroupLayout(2).Get(), emptyBindGroupLayout.Get());
-        EXPECT_EQ(pipeline.GetBindGroupLayout(3).Get(), emptyBindGroupLayout.Get());
-    }
 }
 
 // Test that fragment output validation is for the correct entryPoint
@@ -1123,3 +1110,29 @@
     ASSERT_DEVICE_ERROR(utils::MakeBindGroup(device, bgl0, {{1, buffer}}));
     ASSERT_DEVICE_ERROR(utils::MakeBindGroup(device, bgl1, {{0, buffer}}));
 }
+
+// Test that a pipeline full of explicitly empty BGLs correctly reflects them.
+TEST_F(GetBindGroupLayoutTests, FullOfEmptyBGLs) {
+    // This test works assuming Dawn Native's object deduplication.
+    // Getting the same pointer to equivalent bind group layouts is an implementation detail of Dawn
+    // Native.
+    DAWN_SKIP_TEST_IF(UsesWire());
+
+    wgpu::BindGroupLayout emptyBGL = utils::MakeBindGroupLayout(device, {});
+    wgpu::PipelineLayout pl =
+        utils::MakePipelineLayout(device, {emptyBGL, emptyBGL, emptyBGL, emptyBGL});
+
+    wgpu::ComputePipelineDescriptor pipelineDesc;
+    pipelineDesc.layout = pl;
+    pipelineDesc.compute.entryPoint = "main";
+    pipelineDesc.compute.module = utils::CreateShaderModule(device, R"(
+        @compute @workgroup_size(1) fn main() {
+        }
+    )");
+    wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&pipelineDesc);
+
+    EXPECT_EQ(pipeline.GetBindGroupLayout(0).Get(), emptyBGL.Get());
+    EXPECT_EQ(pipeline.GetBindGroupLayout(1).Get(), emptyBGL.Get());
+    EXPECT_EQ(pipeline.GetBindGroupLayout(2).Get(), emptyBGL.Get());
+    EXPECT_EQ(pipeline.GetBindGroupLayout(3).Get(), emptyBGL.Get());
+}