Clean up transform usage
Use tint::transform::DataMap for inputs as well as outputs.
This allows tint to nest transforms inside each other (e.g. embedding
transforms inside sanitizers) while still having a consistent way to pass
data into and out of these transforms, regardless of nesting depth.
Transforms can now also be fully pre-built and used multiple times,
as the transform itself holds no state.
Bug: tint:389
Change-Id: If1616c77f2776be449021a32f4a6b0b89159aa2a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/48060
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 90115f0..a9cceac 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -1110,11 +1110,13 @@
parseResult->tintSource = std::move(tintSource);
} else {
tint::transform::Manager transformManager;
- transformManager.append(
- std::make_unique<tint::transform::EmitVertexPointSize>());
- transformManager.append(std::make_unique<tint::transform::Spirv>());
- DAWN_TRY_ASSIGN(program,
- RunTransforms(&transformManager, &program, outMessages));
+ transformManager.Add<tint::transform::EmitVertexPointSize>();
+ transformManager.Add<tint::transform::Spirv>();
+
+ tint::transform::DataMap transformInputs;
+
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
+ transformInputs, nullptr, outMessages));
std::vector<uint32_t> spirv;
DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
@@ -1144,8 +1146,10 @@
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program,
+ const tint::transform::DataMap& inputs,
+ tint::transform::DataMap* outputs,
OwnedCompilationMessages* outMessages) {
- tint::transform::Transform::Output output = transform->Run(program);
+ tint::transform::Output output = transform->Run(program, inputs);
if (outMessages != nullptr) {
outMessages->AddMessages(output.program.Diagnostics());
}
@@ -1153,13 +1157,16 @@
std::string err = "Tint program failure: " + output.program.Diagnostics().str();
return DAWN_VALIDATION_ERROR(err.c_str());
}
+ if (outputs != nullptr) {
+ *outputs = std::move(output.data);
+ }
return std::move(output.program);
}
- std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
- const VertexState& vertexState,
- const std::string& entryPoint,
- BindGroupIndex pullingBufferBindingSet) {
+ void AddVertexPullingTransformConfig(const VertexState& vertexState,
+ const std::string& entryPoint,
+ BindGroupIndex pullingBufferBindingSet,
+ tint::transform::DataMap* transformInputs) {
tint::transform::VertexPulling::Config cfg;
cfg.entry_point_name = entryPoint;
cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
@@ -1181,7 +1188,7 @@
cfg.vertex_state.push_back(std::move(layout));
}
- return std::make_unique<tint::transform::VertexPulling>(cfg);
+ transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
}
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
@@ -1314,19 +1321,23 @@
errorStream << "Tint vertex pulling failure:" << std::endl;
tint::transform::Manager transformManager;
- transformManager.append(
- MakeVertexPullingTransform(vertexState, entryPoint, pullingBufferBindingSet));
- transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
- transformManager.append(std::make_unique<tint::transform::Spirv>());
+ transformManager.Add<tint::transform::VertexPulling>();
+ transformManager.Add<tint::transform::EmitVertexPointSize>();
+ transformManager.Add<tint::transform::Spirv>();
if (GetDevice()->IsRobustnessEnabled()) {
- transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
+ transformManager.Add<tint::transform::BoundArrayAccessors>();
}
+ tint::transform::DataMap transformInputs;
+ AddVertexPullingTransformConfig(vertexState, entryPoint, pullingBufferBindingSet,
+ &transformInputs);
+
// A nullptr is passed in for the CompilationMessages here since this method is called
- // during RenderPipeline creation, by which point the shader module's CompilationInfo may
- // have already been queried.
+ // during RenderPipeline creation, by which point the shader module's CompilationInfo
+ // may have already been queried.
tint::Program program;
- DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, nullptr));
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, programIn, transformInputs,
+ nullptr, nullptr));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 556d604..2534792 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -37,6 +37,7 @@
class Program;
namespace transform {
+ class DataMap;
class Transform;
class VertexPulling;
} // namespace transform
@@ -88,12 +89,15 @@
const PipelineLayoutBase* layout);
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
const tint::Program* program,
+ const tint::transform::DataMap& inputs,
+ tint::transform::DataMap* outputs,
OwnedCompilationMessages* messages);
- std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
- const VertexState& vertexState,
- const std::string& entryPoint,
- BindGroupIndex pullingBufferBindingSet);
+ /// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
+ void AddVertexPullingTransformConfig(const VertexState& vertexState,
+ const std::string& entryPoint,
+ BindGroupIndex pullingBufferBindingSet,
+ tint::transform::DataMap* transformInputs);
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
@@ -173,7 +177,7 @@
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const;
- OwnedCompilationMessages* CompilationMessages() {
+ OwnedCompilationMessages* GetCompilationMessages() {
return mCompilationMessages.get();
}
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 8b14210..57b1b51 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -194,7 +194,7 @@
SingleShaderStage stage,
PipelineLayout* layout,
std::string* remappedEntryPointName,
- FirstOffsetInfo* firstOffsetInfo) const {
+ FirstOffsetInfo* firstOffsetInfo) {
ASSERT(!IsError());
ScopedTintICEHandler scopedICEHandler(GetDevice());
@@ -245,30 +245,28 @@
errorStream << "Tint HLSL failure:" << std::endl;
tint::transform::Manager transformManager;
- transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
- if (stage == SingleShaderStage::Vertex) {
- transformManager.append(std::make_unique<tint::transform::FirstIndexOffset>(
- layout->GetFirstIndexOffsetShaderRegister(),
- layout->GetFirstIndexOffsetRegisterSpace()));
- }
- transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
- transformManager.append(std::make_unique<tint::transform::Renamer>());
- transformManager.append(std::make_unique<tint::transform::Hlsl>());
-
tint::transform::DataMap transformInputs;
+
+ transformManager.Add<tint::transform::BoundArrayAccessors>();
+ if (stage == SingleShaderStage::Vertex) {
+ transformManager.Add<tint::transform::FirstIndexOffset>();
+ transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
+ layout->GetFirstIndexOffsetShaderRegister(),
+ layout->GetFirstIndexOffsetRegisterSpace());
+ }
+ transformManager.Add<tint::transform::BindingRemapper>();
+ transformManager.Add<tint::transform::Renamer>();
+ transformManager.Add<tint::transform::Hlsl>();
+
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls));
- tint::transform::Transform::Output output =
- transformManager.Run(GetTintProgram(), transformInputs);
- const tint::Program& program = output.program;
- if (!program.IsValid()) {
- errorStream << "Tint program transform error: " << program.Diagnostics().str()
- << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
+ tint::Program program;
+ tint::transform::DataMap transformOutputs;
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+ &transformOutputs, nullptr));
- if (auto* data = output.data.Get<tint::transform::FirstIndexOffset::Data>()) {
+ if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
firstOffsetInfo->usesVertexIndex = data->has_vertex_index;
if (firstOffsetInfo->usesVertexIndex) {
firstOffsetInfo->vertexIndexOffset = data->first_vertex_offset;
@@ -279,7 +277,7 @@
}
}
- if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
+ if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 98eed9e..d85796d 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -63,7 +63,7 @@
SingleShaderStage stage,
PipelineLayout* layout,
std::string* remappedEntryPointName,
- FirstOffsetInfo* firstOffsetInfo) const;
+ FirstOffsetInfo* firstOffsetInfo);
ResultOrError<std::string> TranslateToHLSLWithSPIRVCross(const char* entryPointName,
SingleShaderStage stage,
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index d302ce7..f88dcc7 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -70,10 +70,12 @@
errorStream << "Tint MSL failure:" << std::endl;
tint::transform::Manager transformManager;
+ tint::transform::DataMap transformInputs;
+
if (stage == SingleShaderStage::Vertex &&
GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
- transformManager.append(
- MakeVertexPullingTransform(*vertexState, entryPointName, kPullingBufferBindingSet));
+ AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
+ &transformInputs);
for (VertexBufferSlot slot :
IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@@ -83,20 +85,16 @@
// this MSL buffer index.
}
}
- transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
- transformManager.append(std::make_unique<tint::transform::Renamer>());
- transformManager.append(std::make_unique<tint::transform::Msl>());
+ transformManager.Add<tint::transform::BoundArrayAccessors>();
+ transformManager.Add<tint::transform::Renamer>();
+ transformManager.Add<tint::transform::Msl>();
- tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
+ tint::Program program;
+ tint::transform::DataMap transformOutputs;
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+ &transformOutputs, nullptr));
- tint::Program& program = output.program;
- if (!program.IsValid()) {
- errorStream << "Tint program transform error: " << program.Diagnostics().str()
- << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- if (auto* data = output.data.Get<tint::transform::Renamer::Data>()) {
+ if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName);
if (it == data->remappings.end()) {
return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 0bd1895..8d8c311 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -87,9 +87,12 @@
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::Spirv>());
+ tint::transform::DataMap transformInputs;
+
tint::Program program;
- DAWN_TRY_ASSIGN(
- program, RunTransforms(&transformManager, GetTintProgram(), CompilationMessages()));
+ DAWN_TRY_ASSIGN(program,
+ RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+ nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index b8a2b23..9f4da48 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -55,14 +55,16 @@
errorStream << "Tint SPIR-V writer failure:" << std::endl;
tint::transform::Manager transformManager;
- transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
- transformManager.append(std::make_unique<tint::transform::EmitVertexPointSize>());
- transformManager.append(std::make_unique<tint::transform::Spirv>());
+ transformManager.Add<tint::transform::BoundArrayAccessors>();
+ transformManager.Add<tint::transform::EmitVertexPointSize>();
+ transformManager.Add<tint::transform::Spirv>();
+
+ tint::transform::DataMap transformInputs;
tint::Program program;
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, parseResult->tintProgram.get(),
- CompilationMessages()));
+ transformInputs, nullptr, GetCompilationMessages()));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
@@ -166,15 +168,10 @@
tint::transform::DataMap transformInputs;
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls));
- tint::transform::Transform::Output output =
- transformManager.Run(GetTintProgram(), transformInputs);
- const tint::Program& program = output.program;
- if (!program.IsValid()) {
- errorStream << "Tint program transform error: " << program.Diagnostics().str()
- << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
+ tint::Program program;
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
+ nullptr, nullptr));
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
index 30754bb..802903f 100644
--- a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -160,7 +160,7 @@
}
// Tests that shader module compilation messages can be queried.
-TEST_F(ShaderModuleValidationTest, CompilationMessages) {
+TEST_F(ShaderModuleValidationTest, GetCompilationMessages) {
// This test works assuming ShaderModule is backed by a dawn_native::ShaderModuleBase, which
// is not the case on the wire.
DAWN_SKIP_TEST_IF(UsesWire());
@@ -172,12 +172,11 @@
dawn_native::ShaderModuleBase* shaderModuleBase =
reinterpret_cast<dawn_native::ShaderModuleBase*>(shaderModule.Get());
- shaderModuleBase->CompilationMessages()->ClearMessages();
- shaderModuleBase->CompilationMessages()->AddMessage("Info Message");
- shaderModuleBase->CompilationMessages()->AddMessage("Warning Message",
- wgpu::CompilationMessageType::Warning);
- shaderModuleBase->CompilationMessages()->AddMessage("Error Message",
- wgpu::CompilationMessageType::Error, 3, 4);
+ dawn_native::OwnedCompilationMessages* messages = shaderModuleBase->GetCompilationMessages();
+ messages->ClearMessages();
+ messages->AddMessage("Info Message");
+ messages->AddMessage("Warning Message", wgpu::CompilationMessageType::Warning);
+ messages->AddMessage("Error Message", wgpu::CompilationMessageType::Error, 3, 4);
auto callback = [](WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info,
void* userdata) {