[ir] Convert ShaderIOSpirv to a free function
Also introduces a free function for the base ShaderIO framework.
Bug: tint:1718
Change-Id: I6b1564e361101ba9e659f109104739e6e6ef4b34
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/143821
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc
index 6380c3d..bb50d8b 100644
--- a/src/tint/lang/core/ir/transform/shader_io.cc
+++ b/src/tint/lang/core/ir/transform/shader_io.cc
@@ -21,8 +21,6 @@
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/type/struct.h"
-TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIO);
-
using namespace tint::builtin::fluent_types; // NOLINT
using namespace tint::number_suffixes; // NOLINT
@@ -70,14 +68,8 @@
return builtin::BuiltinValue::kUndefined;
}
-} // namespace
-
-ShaderIO::ShaderIO() = default;
-
-ShaderIO::~ShaderIO() = default;
-
-/// PIMPL state for the transform, for a single entry point function.
-struct ShaderIO::State {
+/// PIMPL state for the transform.
+struct State {
/// The IR module.
Module* ir = nullptr;
/// The IR builder.
@@ -91,7 +83,7 @@
Function* func = nullptr;
/// The backend state object for the current entry point.
- std::unique_ptr<ShaderIO::BackendState> backend;
+ std::unique_ptr<ShaderIOBackendState> backend;
/// Constructor
/// @param mod the module
@@ -100,7 +92,7 @@
/// Process an entry point.
/// @param f the original entry point function
/// @param bs the backend state object
- void Process(Function* f, std::unique_ptr<ShaderIO::BackendState> bs) {
+ void Process(Function* f, std::unique_ptr<ShaderIOBackendState> bs) {
TINT_SCOPED_ASSIGNMENT(func, f);
backend = std::move(bs);
TINT_DEFER(backend = nullptr);
@@ -245,9 +237,11 @@
}
};
-void ShaderIO::Run(Module* ir) const {
- ShaderIO::State state(ir);
- for (auto* func : ir->functions) {
+} // namespace
+
+void RunShaderIOBase(Module* module, std::function<MakeBackendStateFunc> make_backend_state) {
+ State state(module);
+ for (auto* func : module->functions) {
// Only process entry points.
if (func->Stage() == Function::PipelineStage::kUndefined) {
continue;
@@ -258,11 +252,11 @@
continue;
}
- state.Process(func, MakeBackendState(ir, func));
+ state.Process(func, make_backend_state(module, func));
}
state.Finalize();
}
-ShaderIO::BackendState::~BackendState() = default;
+ShaderIOBackendState::~ShaderIOBackendState() = default;
} // namespace tint::ir::transform
diff --git a/src/tint/lang/core/ir/transform/shader_io.h b/src/tint/lang/core/ir/transform/shader_io.h
index 46e9018..e5b865f 100644
--- a/src/tint/lang/core/ir/transform/shader_io.h
+++ b/src/tint/lang/core/ir/transform/shader_io.h
@@ -19,104 +19,87 @@
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
-#include "src/tint/lang/core/ir/transform/transform.h"
#include "src/tint/lang/core/type/manager.h"
namespace tint::ir::transform {
-/// ShaderIO is a transform that modifies an entry point function's parameters and return value to
-/// prepare them for backend codegen.
-class ShaderIO : public Castable<ShaderIO, Transform> {
- public:
+/// Abstract base class for the state needed to handle IO for a particular backend target.
+struct ShaderIOBackendState {
/// Constructor
- ShaderIO();
+ /// @param mod the IR module
+ /// @param f the entry point function
+ ShaderIOBackendState(Module* mod, Function* f) : ir(mod), func(f) {}
+
/// Destructor
- ~ShaderIO() override;
+ virtual ~ShaderIOBackendState();
- /// @copydoc Transform::Run
- void Run(ir::Module* module) const override;
+ /// Add an input.
+ /// @param name the name of the input
+ /// @param type the type of the input
+ /// @param attributes the IO attributes
+ virtual void AddInput(Symbol name,
+ const type::Type* type,
+ type::StructMemberAttributes attributes) {
+ inputs.Push({name, type, std::move(attributes)});
+ }
- /// Abstract base class for the state needed to handle IO for a particular backend target.
- struct BackendState {
- /// Constructor
- /// @param mod the IR module
- /// @param f the entry point function
- BackendState(Module* mod, Function* f) : ir(mod), func(f) {}
+ /// Add an output.
+ /// @param name the name of the output
+ /// @param type the type of the output
+ /// @param attributes the IO attributes
+ virtual void AddOutput(Symbol name,
+ const type::Type* type,
+ type::StructMemberAttributes attributes) {
+ outputs.Push({name, type, std::move(attributes)});
+ }
- /// Destructor
- virtual ~BackendState();
+ /// Finalize the shader inputs and create any state needed for the new entry point function.
+ /// @returns the list of function parameters for the new entry point
+ virtual Vector<FunctionParam*, 4> FinalizeInputs() = 0;
- /// Add an input.
- /// @param name the name of the input
- /// @param type the type of the input
- /// @param attributes the IO attributes
- virtual void AddInput(Symbol name,
- const type::Type* type,
- type::StructMemberAttributes attributes) {
- inputs.Push({name, type, std::move(attributes)});
- }
+ /// Finalize the shader outputs and create state needed for the new entry point function.
+ /// @returns the return value for the new entry point
+ virtual Value* FinalizeOutputs() = 0;
- /// Add an output.
- /// @param name the name of the output
- /// @param type the type of the output
- /// @param attributes the IO attributes
- virtual void AddOutput(Symbol name,
- const type::Type* type,
- type::StructMemberAttributes attributes) {
- outputs.Push({name, type, std::move(attributes)});
- }
+ /// Get the value of the input at index @p idx
+ /// @param builder the IR builder for new instructions
+ /// @param idx the index of the input
+ /// @returns the value of the input
+ virtual Value* GetInput(Builder& builder, uint32_t idx) = 0;
- /// Finalize the shader inputs and create any state needed for the new entry point function.
- /// @returns the list of function parameters for the new entry point
- virtual Vector<FunctionParam*, 4> FinalizeInputs() = 0;
-
- /// Finalize the shader outputs and create state needed for the new entry point function.
- /// @returns the return value for the new entry point
- virtual Value* FinalizeOutputs() = 0;
-
- /// Get the value of the input at index @p idx
- /// @param builder the IR builder for new instructions
- /// @param idx the index of the input
- /// @returns the value of the input
- virtual Value* GetInput(Builder& builder, uint32_t idx) = 0;
-
- /// Set the value of the output at index @p idx
- /// @param builder the IR builder for new instructions
- /// @param idx the index of the output
- /// @param value the value to set
- virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
-
- protected:
- /// The IR module.
- Module* ir = nullptr;
-
- /// The IR builder.
- Builder b{*ir};
-
- /// The type manager.
- type::Manager& ty{ir->Types()};
-
- /// The original entry point function.
- Function* func = nullptr;
-
- /// The list of shader inputs.
- Vector<type::Manager::StructMemberDesc, 4> inputs;
-
- /// The list of shader outputs.
- Vector<type::Manager::StructMemberDesc, 4> outputs;
- };
+ /// Set the value of the output at index @p idx
+ /// @param builder the IR builder for new instructions
+ /// @param idx the index of the output
+ /// @param value the value to set
+ virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
protected:
- struct State;
+ /// The IR module.
+ Module* ir = nullptr;
- /// Create a backend state object.
- /// @param mod the IR module
- /// @param func the entry point function
- /// @returns the backend state object
- virtual std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
- Function* func) const = 0;
+ /// The IR builder.
+ Builder b{*ir};
+
+ /// The type manager.
+ type::Manager& ty{ir->Types()};
+
+ /// The original entry point function.
+ Function* func = nullptr;
+
+ /// The list of shader inputs.
+ Vector<type::Manager::StructMemberDesc, 4> inputs;
+
+ /// The list of shader outputs.
+ Vector<type::Manager::StructMemberDesc, 4> outputs;
};
+/// The signature for a function that creates a backend state object.
+using MakeBackendStateFunc = std::unique_ptr<ShaderIOBackendState>(Module*, Function*);
+
+/// @param module the module to transform
+/// @param make_backend_state a function that creates a backend state object
+void RunShaderIOBase(Module* module, std::function<MakeBackendStateFunc> make_backend_state);
+
} // namespace tint::ir::transform
#endif // SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_H_
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv.cc b/src/tint/lang/core/ir/transform/shader_io_spirv.cc
index e4fbc97..a7ef0cc 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv.cc
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv.cc
@@ -19,20 +19,16 @@
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/transform/shader_io.h"
+#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/struct.h"
-TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::ShaderIOSpirv);
-
using namespace tint::builtin::fluent_types; // NOLINT
using namespace tint::number_suffixes; // NOLINT
namespace tint::ir::transform {
-ShaderIOSpirv::ShaderIOSpirv() = default;
-
-ShaderIOSpirv::~ShaderIOSpirv() = default;
-
namespace {
/// PIMPL state for the parts of the shader IO transform specific to SPIR-V.
@@ -40,7 +36,7 @@
/// output, and declare global variables for them. The wrapper entry point then loads from and
/// stores to these variables.
/// We also modify the type of the SampleMask builtin to be an array, as required by Vulkan.
-struct StateImpl : ShaderIO::BackendState {
+struct StateImpl : ShaderIOBackendState {
/// The global variable for input builtins.
Var* builtin_input_var = nullptr;
/// The global variable for input locations.
@@ -55,8 +51,8 @@
Vector<uint32_t, 4> output_indices;
/// Constructor
- /// @copydoc ShaderIO::BackendState::BackendState
- using ShaderIO::BackendState::BackendState;
+ /// @copydoc ShaderIO::ShaderIOBackendState::ShaderIOBackendState
+ using ShaderIOBackendState::ShaderIOBackendState;
/// Destructor
~StateImpl() override {}
@@ -164,9 +160,16 @@
};
} // namespace
-std::unique_ptr<ShaderIO::BackendState> ShaderIOSpirv::MakeBackendState(Module* mod,
- Function* func) const {
- return std::make_unique<StateImpl>(mod, func);
+Result<SuccessType, std::string> ShaderIOSpirv(Module* ir) {
+ auto result = ValidateAndDumpIfNeeded(*ir, "ShaderIOSpirv transform");
+ if (!result) {
+ return result;
+ }
+
+ RunShaderIOBase(
+ ir, [](Module* mod, Function* func) { return std::make_unique<StateImpl>(mod, func); });
+
+ return Success;
}
} // namespace tint::ir::transform
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv.h b/src/tint/lang/core/ir/transform/shader_io_spirv.h
index 017b09c..71dea32 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv.h
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv.h
@@ -15,24 +15,22 @@
#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_SPIRV_H_
#define SRC_TINT_LANG_CORE_IR_TRANSFORM_SHADER_IO_SPIRV_H_
-#include "src/tint/lang/core/ir/transform/shader_io.h"
+#include <string>
-#include <memory>
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::ir {
+class Module;
+}
namespace tint::ir::transform {
-/// ShaderIOSpirv is the subclass of the ShaderIO transform used for the SPIR-V backend.
-class ShaderIOSpirv final : public Castable<ShaderIOSpirv, ShaderIO> {
- public:
- /// Constructor
- ShaderIOSpirv();
- /// Destructor
- ~ShaderIOSpirv() override;
-
- /// @copydoc ShaderIO::MakeBackendState
- std::unique_ptr<ShaderIO::BackendState> MakeBackendState(Module* mod,
- Function* func) const override;
-};
+/// ShaderIOSpirv is a transform that modifies each entry point function's parameters and return
+/// value to prepare them for SPIR-V codegen.
+/// @param module the module to transform
+/// @returns an error string on failure
+Result<SuccessType, std::string> ShaderIOSpirv(Module* module);
} // namespace tint::ir::transform
diff --git a/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc b/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
index d2ff101..331d317 100644
--- a/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
+++ b/src/tint/lang/core/ir/transform/shader_io_spirv_test.cc
@@ -45,7 +45,7 @@
auto* expect = src;
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -135,7 +135,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -271,7 +271,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -388,7 +388,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -438,7 +438,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -487,7 +487,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -587,7 +587,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -736,7 +736,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -825,7 +825,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
@@ -918,7 +918,7 @@
}
)";
- Run<ShaderIOSpirv>();
+ Run(ShaderIOSpirv);
EXPECT_EQ(expect, str());
}
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index c6fcd24..f13488e 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -19,27 +19,34 @@
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
#include "src/tint/lang/core/constant/scalar.h"
+#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
+#include "src/tint/lang/core/ir/bitcast.h"
#include "src/tint/lang/core/ir/block.h"
#include "src/tint/lang/core/ir/block_param.h"
#include "src/tint/lang/core/ir/break_if.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
+#include "src/tint/lang/core/ir/convert.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/exit_if.h"
#include "src/tint/lang/core/ir/exit_loop.h"
#include "src/tint/lang/core/ir/exit_switch.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/intrinsic_call.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
+#include "src/tint/lang/core/ir/load_vector_element.h"
#include "src/tint/lang/core/ir/loop.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/next_iteration.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/store.h"
+#include "src/tint/lang/core/ir/store_vector_element.h"
#include "src/tint/lang/core/ir/switch.h"
+#include "src/tint/lang/core/ir/swizzle.h"
#include "src/tint/lang/core/ir/terminate_invocation.h"
#include "src/tint/lang/core/ir/terminator.h"
#include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
@@ -105,8 +112,8 @@
ir::transform::ExpandImplicitSplats{}.Run(module);
ir::transform::HandleMatrixArithmetic{}.Run(module);
ir::transform::MergeReturn{}.Run(module);
- ir::transform::ShaderIOSpirv{}.Run(module);
+ RUN_TRANSFORM(ShaderIOSpirv);
RUN_TRANSFORM(Std140);
RUN_TRANSFORM(VarForDynamicIndex);