[ir] Convert ShaderIOSpirv to a free function

Also introduces a free function for the base ShaderIO framework.

Bug: tint:1718
Change-Id: I6b1564e361101ba9e659f109104739e6e6ef4b34
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/143821
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc
index 6380c3d..bb50d8b 100644
--- a/src/tint/lang/core/ir/transform/shader_io.cc
+++ b/src/tint/lang/core/ir/transform/shader_io.cc
@@ -21,8 +21,6 @@
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/type/struct.h"
 
-TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIO);
-
 using namespace tint::builtin::fluent_types;  // NOLINT
 using namespace tint::number_suffixes;        // NOLINT
 
@@ -70,14 +68,8 @@
     return builtin::BuiltinValue::kUndefined;
 }
 
-}  // namespace
-
-ShaderIO::ShaderIO() = default;
-
-ShaderIO::~ShaderIO() = default;
-
-/// PIMPL state for the transform, for a single entry point function.
-struct ShaderIO::State {
+/// PIMPL state for the transform.
+struct State {
     /// The IR module.
     Module* ir = nullptr;
     /// The IR builder.
@@ -91,7 +83,7 @@
     Function* func = nullptr;
 
     /// The backend state object for the current entry point.
-    std::unique_ptr<ShaderIO::BackendState> backend;
+    std::unique_ptr<ShaderIOBackendState> backend;
 
     /// Constructor
     /// @param mod the module
@@ -100,7 +92,7 @@
     /// Process an entry point.
     /// @param f the original entry point function
     /// @param bs the backend state object
-    void Process(Function* f, std::unique_ptr<ShaderIO::BackendState> bs) {
+    void Process(Function* f, std::unique_ptr<ShaderIOBackendState> bs) {
         TINT_SCOPED_ASSIGNMENT(func, f);
         backend = std::move(bs);
         TINT_DEFER(backend = nullptr);
@@ -245,9 +237,11 @@
     }
 };
 
-void ShaderIO::Run(Module* ir) const {
-    ShaderIO::State state(ir);
-    for (auto* func : ir->functions) {
+}  // namespace
+
+void RunShaderIOBase(Module* module, std::function<MakeBackendStateFunc> make_backend_state) {
+    State state(module);
+    for (auto* func : module->functions) {
         // Only process entry points.
         if (func->Stage() == Function::PipelineStage::kUndefined) {
             continue;
@@ -258,11 +252,11 @@
             continue;
         }
 
-        state.Process(func, MakeBackendState(ir, func));
+        state.Process(func, make_backend_state(module, func));
     }
     state.Finalize();
 }
 
-ShaderIO::BackendState::~BackendState() = default;
+ShaderIOBackendState::~ShaderIOBackendState() = default;
 
 }  // namespace tint::ir::transform
diff --git a/src/tint/lang/core/ir/transform/shader_io.h b/src/tint/lang/core/ir/transform/shader_io.h
index 46e9018..e5b865f 100644
--- a/src/tint/lang/core/ir/transform/shader_io.h
+++ b/src/tint/lang/core/ir/transform/shader_io.h
@@ -19,104 +19,87 @@
 #include <utility>
 
 #include "src/tint/lang/core/ir/builder.h"
-#include "src/tint/lang/core/ir/transform/transform.h"
 #include "src/tint/lang/core/type/manager.h"
 
 namespace tint::ir::transform {
 
-/// ShaderIO is a transform that modifies an entry point function's parameters and return value to
-/// prepare them for backend codegen.
-class ShaderIO : public Castable<ShaderIO, Transform> {
-  public:
+/// Abstract base class for the state needed to handle IO for a particular backend target.
+struct ShaderIOBackendState {
     /// Constructor
-    ShaderIO();
+    /// @param mod the IR module
+    /// @param f the entry point function
+    ShaderIOBackendState(Module* mod, Function* f) : ir(mod), func(f) {}
+
     /// Destructor
-    ~ShaderIO() override;
+    virtual ~ShaderIOBackendState();
 
-    /// @copydoc Transform::Run
-    void Run(ir::Module* module) const override;
+    /// Add an input.
+    /// @param name the name of the input
+    /// @param type the type of the input
+    /// @param attributes the IO attributes
+    virtual void AddInput(Symbol name,
+                          const type::Type* type,
+                          type::StructMemberAttributes attributes) {
+        inputs.Push({name, type, std::move(attributes)});
+    }
 
-    /// Abstract base class for the state needed to handle IO for a particular backend target.
-    struct BackendState {
-        /// Constructor
-        /// @param mod the IR module
-        /// @param f the entry point function
-        BackendState(Module* mod, Function* f) : ir(mod), func(f) {}
+    /// Add an output.
+    /// @param name the name of the output
+    /// @param type the type of the output
+    /// @param attributes the IO attributes
+    virtual void AddOutput(Symbol name,
+                           const type::Type* type,
+                           type::StructMemberAttributes attributes) {
+        outputs.Push({name, type, std::move(attributes)});
+    }
 
-        /// Destructor
-        virtual ~BackendState();
+    /// Finalize the shader inputs and create any state needed for the new entry point function.
+    /// @returns the list of function parameters for the new entry point
+    virtual Vector<FunctionParam*, 4> FinalizeInputs() = 0;
 
-        /// Add an input.
-        /// @param name the name of the input
-        /// @param type the type of the input
-        /// @param attributes the IO attributes
-        virtual void AddInput(Symbol name,
-                              const type::Type* type,
-                              type::StructMemberAttributes attributes) {
-            inputs.Push({name, type, std::move(attributes)});
-        }
+    /// Finalize the shader outputs and create state needed for the new entry point function.
+    /// @returns the return value for the new entry point
+    virtual Value* FinalizeOutputs() = 0;
 
-        /// Add an output.
-        /// @param name the name of the output
-        /// @param type the type of the output
-        /// @param attributes the IO attributes
-        virtual void AddOutput(Symbol name,
-                               const type::Type* type,
-                               type::StructMemberAttributes attributes) {
-            outputs.Push({name, type, std::move(attributes)});
-        }
+    /// Get the value of the input at index @p idx
+    /// @param builder the IR builder for new instructions
+    /// @param idx the index of the input
+    /// @returns the value of the input
+    virtual Value* GetInput(Builder& builder, uint32_t idx) = 0;
 
-        /// Finalize the shader inputs and create any state needed for the new entry point function.
-        /// @returns the list of function parameters for the new entry point
-        virtual Vector<FunctionParam*, 4> FinalizeInputs() = 0;
-
-        /// Finalize the shader outputs and create state needed for the new entry point function.
-        /// @returns the return value for the new entry point
-        virtual Value* FinalizeOutputs() = 0;
-
-        /// Get the value of the input at index @p idx
-        /// @param builder the IR builder for new instructions
-        /// @param idx the index of the input
-        /// @returns the value of the input
-        virtual Value* GetInput(Builder& builder, uint32_t idx) = 0;
-
-        /// Set the value of the output at index @p idx
-        /// @param builder the IR builder for new instructions
-        /// @param idx the index of the output
-        /// @param value the value to set
-        virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
-
-      protected:
-        /// The IR module.
-        Module* ir = nullptr;
-
-        /// The IR builder.
-        Builder b{*ir};
-
-        /// The type manager.
-        type::Manager& ty{ir->Types()};
-
-        /// The original entry point function.
-        Function* func = nullptr;
-
-        /// The list of shader inputs.
-        Vector<type::Manager::StructMemberDesc, 4> inputs;
-
-        /// The list of shader outputs.
-        Vector<type::Manager::StructMemberDesc, 4> outputs;
-    };
+    /// Set the value of the output at index @p idx
+    /// @param builder the IR builder for new instructions
+    /// @param idx the index of the output
+    /// @param value the value to set
+    virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
 
   protected:
-    struct State;
+    /// The IR module.
+    Module* ir = nullptr;
 
-    /// Create a backend state object.
-    /// @param mod the IR module
-    /// @param func the entry point function
-    /// @returns the backend state object
-    virtual std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
-                                                                     Function* func) const = 0;
+    /// The IR builder.
+    Builder b{*ir};
+
+    /// The type manager.
+    type::Manager& ty{ir->Types()};
+
+    /// The original entry point function.
+    Function* func = nullptr;
+
+    /// The list of shader inputs.
+    Vector<type::Manager::StructMemberDesc, 4> inputs;
+
+    /// The list of shader outputs.
+    Vector<type::Manager::StructMemberDesc, 4> outputs;
 };
 
+/// The signature for a function that creates a backend state object.
+using MakeBackendStateFunc = std::unique_ptr<ShaderIOBackendState>(Module*, Function*);
+
+/// @param module the module to transform
+/// @param make_backend_state a function that creates a backend state object
+void RunShaderIOBase(Module* module, std::function<MakeBackendStateFunc> make_backend_state);
+
 }  // namespace tint::ir::transform
 
 #endif  // SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_H_
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv.cc b/src/tint/lang/core/ir/transform/shader_io_spirv.cc
index e4fbc97..a7ef0cc 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv.cc
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv.cc
@@ -19,20 +19,16 @@
 
 #include "src/tint/lang/core/ir/builder.h"
 #include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/transform/shader_io.h"
+#include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/struct.h"
 
-TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIOSpirv);
-
 using namespace tint::builtin::fluent_types;  // NOLINT
 using namespace tint::number_suffixes;        // NOLINT
 
 namespace tint::ir::transform {
 
-ShaderIOSpirv::ShaderIOSpirv() = default;
-
-ShaderIOSpirv::~ShaderIOSpirv() = default;
-
 namespace {
 
 /// PIMPL state for the parts of the shader IO transform specific to SPIR-V.
@@ -40,7 +36,7 @@
 /// output, and declare global variables for them. The wrapper entry point then loads from and
 /// stores to these variables.
 /// We also modify the type of the SampleMask builtin to be an array, as required by Vulkan.
-struct StateImpl : ShaderIO::BackendState {
+struct StateImpl : ShaderIOBackendState {
     /// The global variable for input builtins.
     Var* builtin_input_var = nullptr;
     /// The global variable for input locations.
@@ -55,8 +51,8 @@
     Vector<uint32_t, 4> output_indices;
 
     /// Constructor
-    /// @copydoc ShaderIO::BackendState::BackendState
-    using ShaderIO::BackendState::BackendState;
+    /// @copydoc ShaderIO::ShaderIOBackendState::ShaderIOBackendState
+    using ShaderIOBackendState::ShaderIOBackendState;
     /// Destructor
     ~StateImpl() override {}
 
@@ -164,9 +160,16 @@
 };
 }  // namespace
 
-std::unique_ptr<ShaderIO::BackendState> ShaderIOSpirv::MakeBackendState(Module* mod,
-                                                                        Function* func) const {
-    return std::make_unique<StateImpl>(mod, func);
+Result<SuccessType, std::string> ShaderIOSpirv(Module* ir) {
+    auto result = ValidateAndDumpIfNeeded(*ir, "ShaderIOSpirv transform");
+    if (!result) {
+        return result;
+    }
+
+    RunShaderIOBase(
+        ir, [](Module* mod, Function* func) { return std::make_unique<StateImpl>(mod, func); });
+
+    return Success;
 }
 
 }  // namespace tint::ir::transform
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv.h b/src/tint/lang/core/ir/transform/shader_io_spirv.h
index 017b09c..71dea32 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv.h
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv.h
@@ -15,24 +15,22 @@
 #ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_SPIRV_H_
 #define SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_SPIRV_H_
 
-#include "src/tint/lang/core/ir/transform/shader_io.h"
+#include <string>
 
-#include <memory>
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::ir {
+class Module;
+}
 
 namespace tint::ir::transform {
 
-/// ShaderIOSpirv is the subclass of the ShaderIO transform used for the SPIR-V backend.
-class ShaderIOSpirv final : public Castable<ShaderIOSpirv, ShaderIO> {
-  public:
-    /// Constructor
-    ShaderIOSpirv();
-    /// Destructor
-    ~ShaderIOSpirv() override;
-
-    /// @copydoc ShaderIO::MakeBackendState
-    std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
-                                                             Function* func) const override;
-};
+/// ShaderIOSpirv is a transform that modifies each entry point function's parameters and return
+/// value to prepare them for SPIR-V codegen.
+/// @param module the module to transform
+/// @returns an error string on failure
+Result<SuccessType, std::string> ShaderIOSpirv(Module* module);
 
 }  // namespace tint::ir::transform
 
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc b/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
index d2ff101..331d317 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
@@ -45,7 +45,7 @@
 
     auto* expect = src;
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -135,7 +135,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -271,7 +271,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -388,7 +388,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -438,7 +438,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -487,7 +487,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -587,7 +587,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -736,7 +736,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -825,7 +825,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
@@ -918,7 +918,7 @@
 }
 )";
 
-    Run<ShaderIOSpirv>();
+    Run(ShaderIOSpirv);
 
     EXPECT_EQ(expect, str());
 }
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index c6fcd24..f13488e 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -19,27 +19,34 @@
 #include "spirv/unified1/GLSL.std.450.h"
 #include "spirv/unified1/spirv.h"
 #include "src/tint/lang/core/constant/scalar.h"
+#include "src/tint/lang/core/constant/splat.h"
 #include "src/tint/lang/core/ir/access.h"
 #include "src/tint/lang/core/ir/binary.h"
+#include "src/tint/lang/core/ir/bitcast.h"
 #include "src/tint/lang/core/ir/block.h"
 #include "src/tint/lang/core/ir/block_param.h"
 #include "src/tint/lang/core/ir/break_if.h"
 #include "src/tint/lang/core/ir/construct.h"
 #include "src/tint/lang/core/ir/continue.h"
+#include "src/tint/lang/core/ir/convert.h"
 #include "src/tint/lang/core/ir/core_builtin_call.h"
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/exit_loop.h"
 #include "src/tint/lang/core/ir/exit_switch.h"
 #include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/intrinsic_call.h"
 #include "src/tint/lang/core/ir/let.h"
 #include "src/tint/lang/core/ir/load.h"
+#include "src/tint/lang/core/ir/load_vector_element.h"
 #include "src/tint/lang/core/ir/loop.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
 #include "src/tint/lang/core/ir/next_iteration.h"
 #include "src/tint/lang/core/ir/return.h"
 #include "src/tint/lang/core/ir/store.h"
+#include "src/tint/lang/core/ir/store_vector_element.h"
 #include "src/tint/lang/core/ir/switch.h"
+#include "src/tint/lang/core/ir/swizzle.h"
 #include "src/tint/lang/core/ir/terminate_invocation.h"
 #include "src/tint/lang/core/ir/terminator.h"
 #include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
@@ -105,8 +112,8 @@
     ir::transform::ExpandImplicitSplats{}.Run(module);
     ir::transform::HandleMatrixArithmetic{}.Run(module);
     ir::transform::MergeReturn{}.Run(module);
-    ir::transform::ShaderIOSpirv{}.Run(module);
 
+    RUN_TRANSFORM(ShaderIOSpirv);
     RUN_TRANSFORM(Std140);
     RUN_TRANSFORM(VarForDynamicIndex);