Clean up transform usage

Use tint::transform::DataMap for inputs as well as outputs.

This allows tint to nest transforms inside each other (e.g. embedding
transforms inside sanitizers) while still having a consistent way to pass
data in and out of these transforms, regardless of nesting depth.

Transforms can also now be fully pre-built and reused multiple times, as
a transform no longer holds any state itself.

Bug: tint:389

Change-Id: If1616c77f2776be449021a32f4a6b0b89159aa2a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/48060
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 90115f0..a9cceac 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -1110,11 +1110,13 @@
                     parseResult->tintSource = std::move(tintSource);
                 } else {
                     tint::transform::Manager transformManager;
-                    transformManager.append(
-                        std::make_unique<tint::transform::EmitVertexPointSize>());
-                    transformManager.append(std::make_unique<tint::transform::Spirv>());
-                    DAWN_TRY_ASSIGN(program,
-                                    RunTransforms(&transformManager, &program, outMessages));
+                    transformManager.Add<tint::transform::EmitVertexPointSize>();
+                    transformManager.Add<tint::transform::Spirv>();
+
+                    tint::transform::DataMap transformInputs;
+
+                    DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
+                                                           transformInputs, nullptr, outMessages));
 
                     std::vector<uint32_t> spirv;
                     DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
@@ -1144,8 +1146,10 @@
 
     ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
                                                const tint::Program* program,
+                                               const tint::transform::DataMap& inputs,
+                                               tint::transform::DataMap* outputs,
                                                OwnedCompilationMessages* outMessages) {
-        tint::transform::Transform::Output output = transform->Run(program);
+        tint::transform::Output output = transform->Run(program, inputs);
         if (outMessages != nullptr) {
             outMessages->AddMessages(output.program.Diagnostics());
         }
@@ -1153,13 +1157,16 @@
             std::string err = "Tint program failure: " + output.program.Diagnostics().str();
             return DAWN_VALIDATION_ERROR(err.c_str());
         }
+        if (outputs != nullptr) {
+            *outputs = std::move(output.data);
+        }
         return std::move(output.program);
     }
 
-    std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
-        const VertexState& vertexState,
-        const std::string& entryPoint,
-        BindGroupIndex pullingBufferBindingSet) {
+    void AddVertexPullingTransformConfig(const VertexState& vertexState,
+                                         const std::string& entryPoint,
+                                         BindGroupIndex pullingBufferBindingSet,
+                                         tint::transform::DataMap* transformInputs) {
         tint::transform::VertexPulling::Config cfg;
         cfg.entry_point_name = entryPoint;
         cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
@@ -1181,7 +1188,7 @@
 
             cfg.vertex_state.push_back(std::move(layout));
         }
-        return std::make_unique<tint::transform::VertexPulling>(cfg);
+        transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
     }
 
     MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
@@ -1314,19 +1321,23 @@
         errorStream << "Tint vertex pulling failure:" << std::endl;
 
         tint::transform::Manager transformManager;
-        transformManager.append(
-            MakeVertexPullingTransform(vertexState, entryPoint, pullingBufferBindingSet));
-        transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
-        transformManager.append(std::make_unique<tint::transform::Spirv>());
+        transformManager.Add<tint::transform::VertexPulling>();
+        transformManager.Add<tint::transform::EmitVertexPointSize>();
+        transformManager.Add<tint::transform::Spirv>();
         if (GetDevice()->IsRobustnessEnabled()) {
-            transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
+            transformManager.Add<tint::transform::BoundArrayAccessors>();
         }
 
+        tint::transform::DataMap transformInputs;
+        AddVertexPullingTransformConfig(vertexState, entryPoint, pullingBufferBindingSet,
+                                        &transformInputs);
+
         // A nullptr is passed in for the CompilationMessages here since this method is called
-        // during RenderPipeline creation, by which point the shader module's CompilationInfo may
-        // have already been queried.
+        // during RenderPipeline creation, by which point the shader module's CompilationInfo
+        // may have already been queried.
         tint::Program program;
-        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, nullptr));
+        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, transformInputs,
+                                               nullptr, nullptr));
 
         tint::writer::spirv::Generator generator(&program);
         if (!generator.Generate()) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 556d604..2534792 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -37,6 +37,7 @@
     class Program;
 
     namespace transform {
+        class DataMap;
         class Transform;
         class VertexPulling;
     }  // namespace transform
@@ -88,12 +89,15 @@
                                                             const PipelineLayoutBase* layout);
     ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
                                                const tint::Program* program,
+                                               const tint::transform::DataMap& inputs,
+                                               tint::transform::DataMap* outputs,
                                                OwnedCompilationMessages* messages);
 
-    std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
-        const VertexState& vertexState,
-        const std::string& entryPoint,
-        BindGroupIndex pullingBufferBindingSet);
+    /// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
+    void AddVertexPullingTransformConfig(const VertexState& vertexState,
+                                         const std::string& entryPoint,
+                                         BindGroupIndex pullingBufferBindingSet,
+                                         tint::transform::DataMap* transformInputs);
 
     // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
     // stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
@@ -173,7 +177,7 @@
             const std::string& entryPoint,
             BindGroupIndex pullingBufferBindingSet) const;
 
-        OwnedCompilationMessages* CompilationMessages() {
+        OwnedCompilationMessages* GetCompilationMessages() {
             return mCompilationMessages.get();
         }
 
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 8b14210..57b1b51 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -194,7 +194,7 @@
         SingleShaderStage stage,
         PipelineLayout* layout,
         std::string* remappedEntryPointName,
-        FirstOffsetInfo* firstOffsetInfo) const {
+        FirstOffsetInfo* firstOffsetInfo) {
         ASSERT(!IsError());
 
         ScopedTintICEHandler scopedICEHandler(GetDevice());
@@ -245,30 +245,28 @@
         errorStream << "Tint HLSL failure:" << std::endl;
 
         tint::transform::Manager transformManager;
-        transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
-        if (stage == SingleShaderStage::Vertex) {
-            transformManager.append(std::make_unique<tint::transform::FirstIndexOffset>(
-                layout->GetFirstIndexOffsetShaderRegister(),
-                layout->GetFirstIndexOffsetRegisterSpace()));
-        }
-        transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
-        transformManager.append(std::make_unique<tint::transform::Renamer>());
-        transformManager.append(std::make_unique<tint::transform::Hlsl>());
-
         tint::transform::DataMap transformInputs;
+
+        transformManager.Add<tint::transform::BoundArrayAccessors>();
+        if (stage == SingleShaderStage::Vertex) {
+            transformManager.Add<tint::transform::FirstIndexOffset>();
+            transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
+                layout->GetFirstIndexOffsetShaderRegister(),
+                layout->GetFirstIndexOffsetRegisterSpace());
+        }
+        transformManager.Add<tint::transform::BindingRemapper>();
+        transformManager.Add<tint::transform::Renamer>();
+        transformManager.Add<tint::transform::Hlsl>();
+
         transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
                                                          std::move(accessControls));
-        tint::transform::Transform::Output output =
-            transformManager.Run(GetTintProgram(), transformInputs);
 
-        const tint::Program& program = output.program;
-        if (!program.IsValid()) {
-            errorStream << "Tint program transform error: " << program.Diagnostics().str()
-                        << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
+        tint::Program program;
+        tint::transform::DataMap transformOutputs;
+        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+                                               &transformOutputs, nullptr));
 
-        if (auto* data = output.data.Get<tint::transform::FirstIndexOffset::Data>()) {
+        if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
             firstOffsetInfo->usesVertexIndex = data->has_vertex_index;
             if (firstOffsetInfo->usesVertexIndex) {
                 firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset;
@@ -279,7 +277,7 @@
             }
         }
 
-        if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
+        if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
             auto it = data->remappings.find(entryPointName);
             if (it == data->remappings.end()) {
                 return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 98eed9e..d85796d 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -63,7 +63,7 @@
                                                            SingleShaderStage stage,
                                                            PipelineLayout* layout,
                                                            std::string* remappedEntryPointName,
-                                                           FirstOffsetInfo* firstOffsetInfo) const;
+                                                           FirstOffsetInfo* firstOffsetInfo);
 
         ResultOrError<std::string> TranslateToHLSLWithSPIRVCross(const char* entryPointName,
                                                                  SingleShaderStage stage,
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index d302ce7..f88dcc7 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -70,10 +70,12 @@
         errorStream << "Tint MSL failure:" << std::endl;
 
         tint::transform::Manager transformManager;
+        tint::transform::DataMap transformInputs;
+
         if (stage == SingleShaderStage::Vertex &&
             GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
-            transformManager.append(
-                MakeVertexPullingTransform(*vertexState, entryPointName, kPullingBufferBindingSet));
+            AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
+                                            &transformInputs);
 
             for (VertexBufferSlot slot :
                  IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@@ -83,20 +85,16 @@
                 // this MSL buffer index.
             }
         }
-        transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
-        transformManager.append(std::make_unique<tint::transform::Renamer>());
-        transformManager.append(std::make_unique<tint::transform::Msl>());
+        transformManager.Add<tint::transform::BoundArrayAccessors>();
+        transformManager.Add<tint::transform::Renamer>();
+        transformManager.Add<tint::transform::Msl>();
 
-        tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
+        tint::Program program;
+        tint::transform::DataMap transformOutputs;
+        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+                                               &transformOutputs, nullptr));
 
-        tint::Program& program = output.program;
-        if (!program.IsValid()) {
-            errorStream << "Tint program transform error: " << program.Diagnostics().str()
-                        << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
-        if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
+        if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
             auto it = data->remappings.find(entryPointName);
             if (it == data->remappings.end()) {
                 return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 0bd1895..8d8c311 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -87,9 +87,12 @@
             tint::transform::Manager transformManager;
-            transformManager.append(std::make_unique<tint::transform::Spirv>());
+            transformManager.Add<tint::transform::Spirv>();
 
+            tint::transform::DataMap transformInputs;
+
             tint::Program program;
-            DAWN_TRY_ASSIGN(
-                program, RunTransforms(&transformManager, GetTintProgram(), CompilationMessages()));
+            DAWN_TRY_ASSIGN(program,
+                            RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+                                          nullptr, GetCompilationMessages()));
 
             tint::writer::spirv::Generator generator(&program);
             if (!generator.Generate()) {
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index b8a2b23..9f4da48 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -55,14 +55,16 @@
             errorStream << "Tint SPIR-V writer failure:" << std::endl;
 
             tint::transform::Manager transformManager;
-            transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
-            transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
-            transformManager.append(std::make_unique<tint::transform::Spirv>());
+            transformManager.Add<tint::transform::BoundArrayAccessors>();
+            transformManager.Add<tint::transform::EmitVertexPointSize>();
+            transformManager.Add<tint::transform::Spirv>();
+
+            tint::transform::DataMap transformInputs;
 
             tint::Program program;
             DAWN_TRY_ASSIGN(program,
                             RunTransforms(&transformManager, parseResult->tintProgram.get(),
-                                          CompilationMessages()));
+                                          transformInputs, nullptr, GetCompilationMessages()));
 
             tint::writer::spirv::Generator generator(&program);
             if (!generator.Generate()) {
@@ -166,15 +168,10 @@
         tint::transform::DataMap transformInputs;
         transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
                                                          std::move(accessControls));
-        tint::transform::Transform::Output output =
-            transformManager.Run(GetTintProgram(), transformInputs);
 
-        const tint::Program& program = output.program;
-        if (!program.IsValid()) {
-            errorStream << "Tint program transform error: " << program.Diagnostics().str()
-                        << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
+        tint::Program program;
+        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+                                               nullptr, nullptr));
 
         tint::writer::spirv::Generator generator(&program);
         if (!generator.Generate()) {
diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 30754bb..802903f 100644
--- a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -160,7 +160,7 @@
 }
 
 // Tests that shader module compilation messages can be queried.
-TEST_F(ShaderModuleValidationTest, CompilationMessages) {
+TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
     // This test works assuming ShaderModule is backed by a dawn_native::ShaderModuleBase, which
     // is not the case on the wire.
     DAWN_SKIP_TEST_IF(UsesWire());
@@ -172,12 +172,11 @@
 
     dawn_native::ShaderModuleBase* shaderModuleBase =
         reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get());
-    shaderModuleBase->CompilationMessages()->ClearMessages();
-    shaderModuleBase->CompilationMessages()->AddMessage("Info Message");
-    shaderModuleBase->CompilationMessages()->AddMessage("Warning Message",
-                                                        wgpu::CompilationMessageType::Warning);
-    shaderModuleBase->CompilationMessages()->AddMessage("Error Message",
-                                                        wgpu::CompilationMessageType::Error, 3, 4);
+    dawn_native::OwnedCompilationMessages* messages = shaderModuleBase->GetCompilationMessages();
+    messages->ClearMessages();
+    messages->AddMessage("Info Message");
+    messages->AddMessage("Warning Message", wgpu::CompilationMessageType::Warning);
+    messages->AddMessage("Error Message", wgpu::CompilationMessageType::Error, 3, 4);
 
     auto callback = [](WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info,
                        void* userdata) {