D3D12 fix for register binding offsets.

When both the vert and frag shaders have a UBO binding, the D3D12
backend was using register offset 0 for both, causing a collision,
and the wrong constant value used in one of the shaders.

The fix is to use the binding offsets computed by the BindGroupLayout,
since they know about all of the bindings, not just the ones computed
for each shader. This made it necessary to defer shader compilation
until pipeline layout creation time (as is done in the Metal backend
for similar reasons).

Finally, those bindings offsets computed by the BGL include an offset
for the CBV, UAV and SRV subgroups, so we must add the same register
offset when assigning the BaseShaderRegister to the descriptor ranges
in the PipelineLayout constructor so that they match.

Bug: dawn:20

Change-Id: I18287bf1c06f06dd61288e12da64752f54634466
Reviewed-on: https://dawn-review.googlesource.com/c/1960
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Stephen White <senorblanco@chromium.org>
diff --git a/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp b/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
index 91cd51f..7477e3f 100644
--- a/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
+++ b/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
@@ -40,7 +40,7 @@
             }
         }
 
-        auto SetDescriptorRange = [&](uint32_t index, uint32_t count,
+        auto SetDescriptorRange = [&](uint32_t index, uint32_t count, uint32_t* baseRegister,
                                       D3D12_DESCRIPTOR_RANGE_TYPE type) -> bool {
             if (count == 0) {
                 return false;
@@ -51,36 +51,35 @@
             range.NumDescriptors = count;
             range.RegisterSpace = 0;
             range.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND;
+            range.BaseShaderRegister = *baseRegister;
+            *baseRegister += count;
             // These ranges will be copied and range.BaseShaderRegister will be set in
             // d3d12::PipelineLayout to account for bind group register offsets
             return true;
         };
 
         uint32_t rangeIndex = 0;
+        uint32_t baseRegister = 0;
 
+        std::array<uint32_t, DescriptorType::Count> descriptorOffsets;
         // Ranges 0-2 contain the CBV, UAV, and SRV ranges, if they exist, tightly packed
         // Range 3 contains the Sampler range, if there is one
-        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[CBV],
+        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[CBV], &baseRegister,
                                D3D12_DESCRIPTOR_RANGE_TYPE_CBV)) {
-            rangeIndex++;
+            descriptorOffsets[CBV] = mRanges[rangeIndex++].BaseShaderRegister;
         }
-        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[UAV],
+        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[UAV], &baseRegister,
                                D3D12_DESCRIPTOR_RANGE_TYPE_UAV)) {
-            rangeIndex++;
+            descriptorOffsets[UAV] = mRanges[rangeIndex++].BaseShaderRegister;
         }
-        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[SRV],
+        if (SetDescriptorRange(rangeIndex, mDescriptorCounts[SRV], &baseRegister,
                                D3D12_DESCRIPTOR_RANGE_TYPE_SRV)) {
-            rangeIndex++;
+            descriptorOffsets[SRV] = mRanges[rangeIndex++].BaseShaderRegister;
         }
-        SetDescriptorRange(Sampler, mDescriptorCounts[Sampler],
+        uint32_t zero = 0;
+        SetDescriptorRange(Sampler, mDescriptorCounts[Sampler], &zero,
                            D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER);
-
-        // descriptors ranges are offset by the offset + size of the previous range
-        std::array<uint32_t, DescriptorType::Count> descriptorOffsets;
-        descriptorOffsets[CBV] = 0;
-        descriptorOffsets[UAV] = descriptorOffsets[CBV] + mDescriptorCounts[CBV];
-        descriptorOffsets[SRV] = descriptorOffsets[UAV] + mDescriptorCounts[UAV];
-        descriptorOffsets[Sampler] = 0;  // samplers are in a different heap
+        descriptorOffsets[Sampler] = 0;
 
         for (uint32_t binding : IterateBitSet(groupInfo.mask)) {
             switch (groupInfo.types[binding]) {
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index edcbd93..67f4cbc 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -33,7 +33,7 @@
         compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
         const ShaderModule* module = ToBackend(descriptor->module);
-        const std::string& hlslSource = module->GetHLSLSource();
+        const std::string& hlslSource = module->GetHLSLSource(ToBackend(GetLayout()));
 
         ComPtr<ID3DBlob> compiledShader;
         ComPtr<ID3DBlob> errors;
diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
index ee6e992..1ee21e1 100644
--- a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
@@ -65,7 +65,7 @@
 
                 for (uint32_t i = 0; i < rangeCount; ++i) {
                     ranges[rangeIndex] = descriptorRanges[i];
-                    ranges[rangeIndex].BaseShaderRegister = group * kMaxBindingsPerGroup;
+                    ranges[rangeIndex].BaseShaderRegister += group * kMaxBindingsPerGroup;
                     rangeIndex++;
                 }
 
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index b375a30..63385db 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -83,7 +83,7 @@
         for (auto stage : IterateStages(GetStageMask())) {
             const auto& module = ToBackend(builder->GetStageInfo(stage).module);
             const auto& entryPoint = builder->GetStageInfo(stage).entryPoint;
-            const auto& hlslSource = module->GetHLSLSource();
+            const auto& hlslSource = module->GetHLSLSource(ToBackend(GetLayout()));
 
             const char* compileTarget = nullptr;
 
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index f9c971e..c4bd12e 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -15,39 +15,23 @@
 #include "dawn_native/d3d12/ShaderModuleD3D12.h"
 
 #include "common/Assert.h"
+#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
+#include "dawn_native/d3d12/PipelineLayoutD3D12.h"
 
 #include <spirv-cross/spirv_hlsl.hpp>
 
 namespace dawn_native { namespace d3d12 {
 
-    // TODO(kainino@chromium.org): Consider replacing this with a generic enum_map.
-    template <typename T>
-    class BindingTypeMap {
-      public:
-        T& operator[](dawn::BindingType type) {
-            switch (type) {
-                case dawn::BindingType::UniformBuffer:
-                    return mMap[0];
-                case dawn::BindingType::Sampler:
-                    return mMap[1];
-                case dawn::BindingType::SampledTexture:
-                    return mMap[2];
-                case dawn::BindingType::StorageBuffer:
-                    return mMap[3];
-                default:
-                    DAWN_UNREACHABLE();
-            }
-        }
-
-      private:
-        static constexpr int kNumBindingTypes = 4;
-        std::array<T, kNumBindingTypes> mMap{};
-    };
-
     ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
         : ShaderModuleBase(device, descriptor) {
-        spirv_cross::CompilerHLSL compiler(descriptor->code, descriptor->codeSize);
+        mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
+        spirv_cross::CompilerHLSL compiler(mSpirv);
+        ExtractSpirvInfo(compiler);
+    }
+
+    const std::string ShaderModule::GetHLSLSource(PipelineLayout* layout) const {
+        spirv_cross::CompilerHLSL compiler(mSpirv);
 
         spirv_cross::CompilerGLSL::Options options_glsl;
         options_glsl.vertex.fixup_clipspace = true;
@@ -58,30 +42,22 @@
         options_hlsl.shader_model = 51;
         compiler.set_hlsl_options(options_hlsl);
 
-        ExtractSpirvInfo(compiler);
-
-        // rename bindings so that each register type c/u/t/s starts at 0 and then offset by
-        // kMaxBindingsPerGroup * bindGroupIndex
-        const auto& moduleBindingInfo = GetBindingInfo();
+        const ModuleBindingInfo& moduleBindingInfo = GetBindingInfo();
         for (uint32_t group = 0; group < moduleBindingInfo.size(); ++group) {
+            const auto& bindingOffsets =
+                ToBackend(layout->GetBindGroupLayout(group))->GetBindingOffsets();
             const auto& groupBindingInfo = moduleBindingInfo[group];
-
-            BindingTypeMap<uint32_t> baseRegisters{};
-            for (const auto& bindingInfo : groupBindingInfo) {
+            for (uint32_t binding = 0; binding < groupBindingInfo.size(); ++binding) {
+                const BindingInfo& bindingInfo = groupBindingInfo[binding];
                 if (bindingInfo.used) {
-                    uint32_t& baseRegister = baseRegisters[bindingInfo.type];
                     uint32_t bindGroupOffset = group * kMaxBindingsPerGroup;
+                    uint32_t bindingOffset = bindingOffsets[binding];
                     compiler.set_decoration(bindingInfo.id, spv::DecorationBinding,
-                                            bindGroupOffset + baseRegister++);
+                                            bindGroupOffset + bindingOffset);
                 }
             }
         }
-
-        mHlslSource = compiler.compile();
-    }
-
-    const std::string& ShaderModule::GetHLSLSource() const {
-        return mHlslSource;
+        return compiler.compile();
     }
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 11065c1..7cafd1c 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -20,15 +20,16 @@
 namespace dawn_native { namespace d3d12 {
 
     class Device;
+    class PipelineLayout;
 
     class ShaderModule : public ShaderModuleBase {
       public:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
 
-        const std::string& GetHLSLSource() const;
+        const std::string GetHLSLSource(PipelineLayout* layout) const;
 
       private:
-        std::string mHlslSource;
+        std::vector<uint32_t> mSpirv;
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp
index 8585151..a8a26c7 100644
--- a/src/tests/end2end/BindGroupTests.cpp
+++ b/src/tests/end2end/BindGroupTests.cpp
@@ -12,9 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "common/Constants.h"
 #include "tests/DawnTest.h"
 #include "utils/DawnHelpers.h"
 
+constexpr static unsigned int kRTSize = 8;
+
 class BindGroupTests : public DawnTest {
 protected:
     dawn::CommandBuffer CreateSimpleComputeCommandBuffer(
@@ -35,7 +38,7 @@
     dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
         device,
         {
-            {0, dawn::ShaderStageBit::Compute, dawn::BindingType::UniformBuffer },
+            {0, dawn::ShaderStageBit::Compute, dawn::BindingType::UniformBuffer},
         }
     );
     dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
@@ -75,4 +78,195 @@
     queue.Submit(2, cb);
 }
 
+// Test a bindgroup containing a UBO which is used in both the vertex and fragment shader.
+// It contains a transformation matrix for the VS and the fragment color for the FS.
+// These must result in different register offsets in the native APIs.
+TEST_P(BindGroupTests, ReusedUBO) {
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        layout (set = 0, binding = 0) uniform vertexUniformBuffer {
+            mat2 transform;
+        };
+        void main() {
+            const vec2 pos[3] = vec2[3](vec2(-1.f, -1.f), vec2(1.f, -1.f), vec2(-1.f, 1.f));
+            gl_Position = vec4(transform * pos[gl_VertexIndex], 0.f, 1.f);
+        })"
+    );
+
+    dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        layout (set = 0, binding = 1) uniform fragmentUniformBuffer {
+            vec4 color;
+        };
+        layout(location = 0) out vec4 fragColor;
+        void main() {
+            fragColor = color;
+        })"
+    );
+
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device,
+        {
+            {0, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer},
+            {1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer},
+        }
+    );
+    dawn::PipelineLayout pipelineLayout = utils::MakeBasicPipelineLayout(device, &bgl);
+
+    dawn::RenderPipeline pipeline = device.CreateRenderPipelineBuilder()
+        .SetColorAttachmentFormat(0, renderPass.colorFormat)
+        .SetLayout(pipelineLayout)
+        .SetPrimitiveTopology(dawn::PrimitiveTopology::TriangleList)
+        .SetStage(dawn::ShaderStage::Vertex, vsModule, "main")
+        .SetStage(dawn::ShaderStage::Fragment, fsModule, "main")
+        .GetResult();
+    struct Data {
+        float transform[8];
+        char padding[256 - sizeof(Data::transform)];
+        float color[4];
+    };
+    constexpr float dummy = 0.0f;
+    Data data {
+        { 1.f, 0.f, dummy, dummy, 0.f, 1.0f, dummy, dummy },
+        { 0 },
+        { 0.f, 1.f, 0.f, 1.f },
+    };
+    dawn::Buffer buffer = utils::CreateBufferFromData(device, &data, sizeof(data), dawn::BufferUsageBit::Uniform);
+    dawn::BufferView vertUBOBufferView =
+        buffer.CreateBufferViewBuilder().SetExtent(offsetof(Data, transform), sizeof(Data::transform)).GetResult();
+    dawn::BufferView fragUBOBufferView =
+        buffer.CreateBufferViewBuilder().SetExtent(offsetof(Data, color), sizeof(Data::color)).GetResult();
+    dawn::BindGroup bindGroup = device.CreateBindGroupBuilder()
+        .SetLayout(bgl)
+        .SetBufferViews(0, 1, &vertUBOBufferView)
+        .SetBufferViews(1, 1, &fragUBOBufferView)
+        .GetResult();
+
+    dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+    dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+    pass.SetRenderPipeline(pipeline);
+    pass.SetBindGroup(0, bindGroup);
+    pass.DrawArrays(3, 1, 0, 0);
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = builder.GetResult();
+    queue.Submit(1, &commands);
+
+    RGBA8 filled(0, 255, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
+// Test a bindgroup containing a UBO in the vertex shader and a sampler and texture in the fragment shader.
+// In D3D12 for example, these different types of bindings end up in different namespaces, but the register
+// offsets used must match between the shader module and descriptor range.
+TEST_P(BindGroupTests, UBOSamplerAndTexture) {
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        layout (set = 0, binding = 0) uniform vertexUniformBuffer {
+            mat2 transform;
+        };
+        void main() {
+            const vec2 pos[3] = vec2[3](vec2(-1.f, -1.f), vec2(1.f, -1.f), vec2(-1.f, 1.f));
+            gl_Position = vec4(transform * pos[gl_VertexIndex], 0.f, 1.f);
+        })"
+    );
+
+    dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        layout (set = 0, binding = 1) uniform sampler samp;
+        layout (set = 0, binding = 2) uniform texture2D tex;
+        layout (location = 0) out vec4 fragColor;
+        void main() {
+            fragColor = texture(sampler2D(tex, samp), gl_FragCoord.xy);
+        })"
+    );
+
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device,
+        {
+            {0, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer},
+            {1, dawn::ShaderStageBit::Fragment, dawn::BindingType::Sampler},
+            {2, dawn::ShaderStageBit::Fragment, dawn::BindingType::SampledTexture},
+        }
+    );
+    dawn::PipelineLayout pipelineLayout = utils::MakeBasicPipelineLayout(device, &bgl);
+
+    dawn::RenderPipeline pipeline = device.CreateRenderPipelineBuilder()
+        .SetColorAttachmentFormat(0, renderPass.colorFormat)
+        .SetLayout(pipelineLayout)
+        .SetPrimitiveTopology(dawn::PrimitiveTopology::TriangleList)
+        .SetStage(dawn::ShaderStage::Vertex, vsModule, "main")
+        .SetStage(dawn::ShaderStage::Fragment, fsModule, "main")
+        .GetResult();
+    constexpr float dummy = 0.0f;
+    constexpr float transform[] = { 1.f, 0.f, dummy, dummy, 0.f, 1.f, dummy, dummy };
+    dawn::Buffer buffer = utils::CreateBufferFromData(device, &transform, sizeof(transform), dawn::BufferUsageBit::Uniform);
+    dawn::BufferView vertUBOBufferView =
+        buffer.CreateBufferViewBuilder().SetExtent(0, sizeof(transform)).GetResult();
+    dawn::SamplerDescriptor samplerDescriptor;
+    samplerDescriptor.minFilter = dawn::FilterMode::Nearest;
+    samplerDescriptor.magFilter = dawn::FilterMode::Nearest;
+    samplerDescriptor.mipmapFilter = dawn::FilterMode::Nearest;
+    samplerDescriptor.addressModeU = dawn::AddressMode::ClampToEdge;
+    samplerDescriptor.addressModeV = dawn::AddressMode::ClampToEdge;
+    samplerDescriptor.addressModeW = dawn::AddressMode::ClampToEdge;
+    dawn::Sampler sampler = device.CreateSampler(&samplerDescriptor);
+
+    dawn::TextureDescriptor descriptor;
+    descriptor.dimension = dawn::TextureDimension::e2D;
+    descriptor.size.width = kRTSize;
+    descriptor.size.height = kRTSize;
+    descriptor.size.depth = 1;
+    descriptor.arrayLayer = 1;
+    descriptor.format = dawn::TextureFormat::R8G8B8A8Unorm;
+    descriptor.levelCount = 1;
+    descriptor.usage = dawn::TextureUsageBit::TransferDst | dawn::TextureUsageBit::Sampled;
+    dawn::Texture texture = device.CreateTexture(&descriptor);
+    dawn::TextureView textureView = texture.CreateDefaultTextureView();
+    int width = kRTSize, height = kRTSize;
+    int widthInBytes = width * sizeof(RGBA8);
+    widthInBytes = (widthInBytes + 255) & ~255;
+    int sizeInBytes = widthInBytes * height;
+    int size = sizeInBytes / sizeof(RGBA8);
+    std::vector<RGBA8> data = std::vector<RGBA8>(size);
+    for (int i = 0; i < size; i++) {
+        data[i] = RGBA8(0, 255, 0, 255);
+    }
+    dawn::Buffer stagingBuffer = utils::CreateBufferFromData(device, data.data(), sizeInBytes, dawn::BufferUsageBit::TransferSrc);
+    dawn::BindGroup bindGroup = device.CreateBindGroupBuilder()
+        .SetLayout(bgl)
+        .SetBufferViews(0, 1, &vertUBOBufferView)
+        .SetSamplers(1, 1, &sampler)
+        .SetTextureViews(2, 1, &textureView)
+        .GetResult();
+
+    dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+    builder.CopyBufferToTexture(stagingBuffer, 0, widthInBytes, texture, 0, 0, 0, width, height, 1, 0, 0);
+    dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+    pass.SetRenderPipeline(pipeline);
+    pass.SetBindGroup(0, bindGroup);
+    pass.DrawArrays(3, 1, 0, 0);
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = builder.GetResult();
+    queue.Submit(1, &commands);
+
+    RGBA8 filled(0, 255, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color,    min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
 DAWN_INSTANTIATE_TEST(BindGroupTests, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);