Use Tint HLSL options to pass dynamic storage buffer sizes

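Dynamic storage buffer sizes are loaded from a uniform buffer
which is bound to a set of root constants in the D3D12 root
signature. Tint's array_length_from_uniform HLSL option maps each
dynamic storage buffer binding point to its index in that buffer.

This lets out-of-bounds accesses to dynamic storage buffers be
clamped on D3D12, so the D3D12 suppression in
ClampedOOBDynamicBufferOffsetTests.CheckOOBAccess is removed.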

Bug: dawn:429
Change-Id: I3bf0d9bbdb7a5b0a8c0f624f18081c6bf8d45fca
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68960
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index b70c676..3001c6c 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -77,6 +77,12 @@
             output << ")";
         }
 
+        // Serializes a fundamental value (integer, bool, ...) using its operator<< formatting.
+        template <typename T, typename = std::enable_if_t<std::is_fundamental<T>::value>>
+        void Serialize(std::stringstream& output, const T& val) {
+            output << val;
+        }
+
         template <typename T>
         void Serialize(std::stringstream& output,
                        const std::unordered_map<tint::transform::BindingPoint, T>& map) {
@@ -93,6 +99,16 @@
             output << ")";
         }
 
+        // Writes the ubo_binding and bindpoint_to_size_index fields in a labeled, fixed order.
+        void Serialize(std::stringstream& output,
+                       const tint::writer::ArrayLengthFromUniformOptions& options) {
+            output << "(ArrayLengthFromUniformOptions" << " ubo_binding=";
+            Serialize(output, options.ubo_binding);
+            output << " bindpoint_to_size_index=";
+            Serialize(output, options.bindpoint_to_size_index);
+            output << ")";
+        }
+
         // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
         std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
             std::ostringstream out;
@@ -172,12 +188,13 @@
             SingleShaderStage stage;
             uint32_t compileFlags;
             bool disableSymbolRenaming;
-            tint::transform::BindingRemapper::BindingPoints bindingPoints;
-            tint::transform::BindingRemapper::AccessControls accessControls;
+            tint::transform::BindingRemapper::BindingPoints remappedBindingPoints;
+            tint::transform::BindingRemapper::AccessControls remappedAccessControls;
             bool isRobustnessEnabled;
             bool usesNumWorkgroups;
             uint32_t numWorkgroupsRegisterSpace;
             uint32_t numWorkgroupsShaderRegister;
+            tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
             std::vector<std::pair<std::string, std::string>> defineStrings;
 
             // FXC/DXC common inputs
@@ -212,17 +229,23 @@
                 using tint::transform::BindingPoint;
                 using tint::transform::BindingRemapper;
 
-                BindingRemapper::BindingPoints bindingPoints;
-                BindingRemapper::AccessControls accessControls;
+                BindingRemapper::BindingPoints remappedBindingPoints;
+                BindingRemapper::AccessControls remappedAccessControls;
 
-                // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify the
-                // Tint AST to make the "bindings" decoration match the offset chosen by
-                // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
-                // assigned to each interface variable.
+                tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
+                arrayLengthFromUniform.ubo_binding = {
+                    layout->GetDynamicStorageBufferLengthsRegisterSpace(),
+                    layout->GetDynamicStorageBufferLengthsShaderRegister()};
+
                 const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
                 for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
                     const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
                     const auto& groupBindingInfo = moduleBindingInfo[group];
+
+                    // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
+                    // the Tint AST to make the "bindings" decoration match the offset chosen by
+                    // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
+                    // assigned to each interface variable.
                     for (const auto& it : groupBindingInfo) {
                         BindingNumber binding = it.first;
                         auto const& bindingInfo = it.second;
@@ -232,7 +255,7 @@
                         BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
                                                      bgl->GetShaderRegister(bindingIndex)};
                         if (srcBindingPoint != dstBindingPoint) {
-                            bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+                            remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                         }
 
                         // Declaring a read-only storage buffer in HLSL but specifying a storage
@@ -246,7 +269,29 @@
                               bgl->GetBindingInfo(bindingIndex).buffer.type ==
                                   kInternalStorageBufferBinding));
                         if (forceStorageBufferAsUAV) {
-                            accessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
+                            remappedAccessControls.emplace(srcBindingPoint,
+                                                           tint::ast::Access::kReadWrite);
+                        }
+                    }
+
+                    // Add arrayLengthFromUniform options
+                    {
+                        for (const auto& bindingAndRegisterOffset :
+                             layout->GetDynamicStorageBufferLengthInfo()[group]
+                                 .bindingAndRegisterOffsets) {
+                            BindingNumber binding = bindingAndRegisterOffset.binding;
+                            uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
+
+                            BindingPoint bindingPoint{static_cast<uint32_t>(group),
+                                                      static_cast<uint32_t>(binding)};
+                            // Get the renamed binding point if it was remapped.
+                            auto it = remappedBindingPoints.find(bindingPoint);
+                            if (it != remappedBindingPoints.end()) {
+                                bindingPoint = it->second;
+                            }
+
+                            arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
+                                                                                   registerOffset);
                         }
                     }
                 }
@@ -259,14 +304,15 @@
                 request.compileFlags = compileFlags;
                 request.disableSymbolRenaming =
                     device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
-                request.bindingPoints = std::move(bindingPoints);
-                request.accessControls = std::move(accessControls);
+                request.remappedBindingPoints = std::move(remappedBindingPoints);
+                request.remappedAccessControls = std::move(remappedAccessControls);
                 request.isRobustnessEnabled = device->IsRobustnessEnabled();
                 request.disableWorkgroupInit =
                     device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
                 request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
                 request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
                 request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
+                request.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
                 request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
                 request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
                 request.deviceInfo = &device->GetDeviceInfo();
@@ -312,16 +358,19 @@
                 stream << " compileFlags=" << compileFlags;
                 stream << " disableSymbolRenaming=" << disableSymbolRenaming;
 
-                stream << " bindingPoints=";
-                Serialize(stream, bindingPoints);
+                stream << " remappedBindingPoints=";
+                Serialize(stream, remappedBindingPoints);
 
-                stream << " accessControls=";
-                Serialize(stream, accessControls);
+                stream << " remappedAccessControls=";
+                Serialize(stream, remappedAccessControls);
 
                 stream << " useNumWorkgroups=" << usesNumWorkgroups;
                 stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
                 stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
 
+                stream << " arrayLengthFromUniform=";
+                Serialize(stream, arrayLengthFromUniform);
+
                 stream << " shaderModel=" << deviceInfo->shaderModel;
                 stream << " disableWorkgroupInit=" << disableWorkgroupInit;
                 stream << " isRobustnessEnabled=" << isRobustnessEnabled;
@@ -564,6 +613,7 @@
             if (request.isRobustnessEnabled) {
                 transformManager.Add<tint::transform::Robustness>();
             }
+
             transformManager.Add<tint::transform::BindingRemapper>();
 
             transformManager.Add<tint::transform::SingleEntryPoint>();
@@ -582,7 +632,8 @@
             // different types.
             const bool mayCollide = true;
             transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
-                std::move(request.bindingPoints), std::move(request.accessControls), mayCollide);
+                std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls),
+                mayCollide);
 
             tint::Program transformedProgram;
             tint::transform::DataMap transformOutputs;
@@ -610,6 +661,12 @@
                 options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
                 options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
             }
+            // TODO(dawn:549): HLSL generation outputs the indices into the
+            // array_length_from_uniform buffer that were actually used. When the blob cache can
+            // store more than compiled shaders, we should reflect these used indices and store
+            // them as well. This would allow us to only upload root constants that are actually
+            // read by the shader.
+            options.array_length_from_uniform = request.arrayLengthFromUniform;
             auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
             DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s",
                             result.error);
diff --git a/src/tests/end2end/DynamicBufferOffsetTests.cpp b/src/tests/end2end/DynamicBufferOffsetTests.cpp
index 0c0f3eb..c55ef5a 100644
--- a/src/tests/end2end/DynamicBufferOffsetTests.cpp
+++ b/src/tests/end2end/DynamicBufferOffsetTests.cpp
@@ -414,11 +414,6 @@
 
 // Test robust buffer access behavior for out of bounds accesses to dynamic buffer bindings.
 TEST_P(ClampedOOBDynamicBufferOffsetTests, CheckOOBAccess) {
-    // TODO(crbug.com/dawn/429): Dynamic storage buffers are not bounds clamped on D3D12.
-    DAWN_SUPPRESS_TEST_IF(IsD3D12() && ((GetParam().mOOBRead && GetParam().mReadBufferUsage ==
-                                                                    wgpu::BufferUsage::Storage) ||
-                                        GetParam().mOOBWrite));
-
     static constexpr uint32_t kArrayLength = 10u;
 
     // Out-of-bounds access will start halfway into the array and index off the end.