D3D12: Move transform/FirstIndexOffset outside of TranslateToHLSL

This is a minor refactor that runs transform/FirstIndexOffset in a
separate pass, outside of TranslateToHLSL. It should be functionally the
same as what currently exists.

This is to prepare for creating pipeline cache keys based on the WGSL.

Bug: dawn:1103
Change-Id: Ifc516079bafe2449d422f8bd8485b2459cd3d181
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63224
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Michael Tang <tangm@microsoft.com>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 426cdb3..718e0c1 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -181,11 +181,11 @@
     }
 
     ResultOrError<std::string> ShaderModule::TranslateToHLSL(
+        const tint::Program* program,
         const char* entryPointName,
         SingleShaderStage stage,
         PipelineLayout* layout,
-        std::string* remappedEntryPointName,
-        FirstOffsetInfo* firstOffsetInfo) const {
+        std::string* remappedEntryPointName) const {
         ASSERT(!IsError());
 
         ScopedTintICEHandler scopedICEHandler(GetDevice());
@@ -244,17 +244,6 @@
         }
         transformManager.Add<tint::transform::BindingRemapper>();
 
-        // The FirstIndexOffset transform must be done after the BindingRemapper because it assumes
-        // that the register space has already flattened (and uses the next register). Otherwise
-        // intermediate ASTs can be produced where the extra registers conflict with one of the
-        // user-declared bind points.
-        if (stage == SingleShaderStage::Vertex) {
-            transformManager.Add<tint::transform::FirstIndexOffset>();
-            transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
-                layout->GetFirstIndexOffsetShaderRegister(),
-                layout->GetFirstIndexOffsetRegisterSpace());
-        }
-
         transformManager.Add<tint::transform::Renamer>();
 
         if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
@@ -269,21 +258,11 @@
         transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
                                                          std::move(accessControls), mayCollide);
 
-        tint::Program program;
+        tint::Program transformedProgram;
         tint::transform::DataMap transformOutputs;
-        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
-                                               &transformOutputs, nullptr));
-
-        if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
-            firstOffsetInfo->usesVertexIndex = data->has_vertex_index;
-            if (firstOffsetInfo->usesVertexIndex) {
-                firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset;
-            }
-            firstOffsetInfo->usesInstanceIndex = data->has_instance_index;
-            if (firstOffsetInfo->usesInstanceIndex) {
-                firstOffsetInfo->instanceIndexOffset = data->first_instance_offset;
-            }
-        }
+        DAWN_TRY_ASSIGN(
+            transformedProgram,
+            RunTransforms(&transformManager, program, transformInputs, &transformOutputs, nullptr));
 
         if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
             auto it = data->remappings.find(entryPointName);
@@ -302,7 +281,7 @@
 
         tint::writer::hlsl::Options options;
         options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
-        auto result = tint::writer::hlsl::Generate(&program, options);
+        auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
         if (!result.success) {
             errorStream << "Generator: " << result.error << std::endl;
             return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
@@ -317,13 +296,48 @@
                                                         uint32_t compileFlags) {
         Device* device = ToBackend(GetDevice());
 
+        CompiledShader compiledShader = {};
+
+        tint::transform::Manager transformManager;
+        tint::transform::DataMap transformInputs;
+
+        const tint::Program* program;
+        tint::Program programAsValue;
+        if (stage == SingleShaderStage::Vertex) {
+            transformManager.Add<tint::transform::FirstIndexOffset>();
+            transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
+                layout->GetFirstIndexOffsetShaderRegister(),
+                layout->GetFirstIndexOffsetRegisterSpace());
+
+            tint::transform::DataMap transformOutputs;
+            DAWN_TRY_ASSIGN(programAsValue,
+                            RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+                                          &transformOutputs, nullptr));
+
+            if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
+                // TODO(dawn:549): Consider adding this information to the pipeline cache once we
+                // can store more than the shader blob in it.
+                compiledShader.firstOffsetInfo.usesVertexIndex = data->has_vertex_index;
+                if (compiledShader.firstOffsetInfo.usesVertexIndex) {
+                    compiledShader.firstOffsetInfo.vertexIndexOffset = data->first_vertex_offset;
+                }
+                compiledShader.firstOffsetInfo.usesInstanceIndex = data->has_instance_index;
+                if (compiledShader.firstOffsetInfo.usesInstanceIndex) {
+                    compiledShader.firstOffsetInfo.instanceIndexOffset =
+                        data->first_instance_offset;
+                }
+            }
+
+            program = &programAsValue;
+        } else {
+            program = GetTintProgram();
+        }
+
         // Compile the source shader to HLSL.
         std::string hlslSource;
         std::string remappedEntryPoint;
-        CompiledShader compiledShader = {};
-        DAWN_TRY_ASSIGN(hlslSource,
-                        TranslateToHLSL(entryPointName, stage, layout, &remappedEntryPoint,
-                                        &compiledShader.firstOffsetInfo));
+        DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(program, entryPointName, stage, layout,
+                                                    &remappedEntryPoint));
         entryPointName = remappedEntryPoint.c_str();
 
         if (device->IsToggleEnabled(Toggle::DumpShaders)) {
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 7fbb78a..436d7bd 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -59,11 +59,11 @@
         ~ShaderModule() override = default;
         MaybeError Initialize(ShaderModuleParseResult* parseResult);
 
-        ResultOrError<std::string> TranslateToHLSL(const char* entryPointName,
+        ResultOrError<std::string> TranslateToHLSL(const tint::Program* program,
+                                                   const char* entryPointName,
                                                    SingleShaderStage stage,
                                                    PipelineLayout* layout,
-                                                   std::string* remappedEntryPointName,
-                                                   FirstOffsetInfo* firstOffsetInfo) const;
+                                                   std::string* remappedEntryPointName) const;
 
         ResultOrError<PersistentCacheKey> CreateHLSLKey(const char* entryPointName,
                                                         SingleShaderStage stage,