[ir][spirv-writer] Add builder to create constants
The SPIRV-Writer has to create a few constants; currently it uses the IR
directly and the `constant_values` constant::Manager. This CL adds the
`Builder` into the SPIRV-Writer and then uses the builder helper methods
to create the constants.
Bug: tint:1906
Change-Id: I6722dfeb6da2c24cb86d3dceb21443bf23ee17db
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144321
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index ef086ee..5b81aba3 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -150,7 +150,7 @@
} // namespace
Printer::Printer(ir::Module* module, bool zero_init_workgroup_mem)
- : ir_(module), zero_init_workgroup_memory_(zero_init_workgroup_mem) {}
+ : ir_(module), b_(*module), zero_init_workgroup_memory_(zero_init_workgroup_mem) {}
Result<std::vector<uint32_t>, std::string> Printer::Generate() {
auto valid = ir::ValidateAndDumpIfNeeded(*ir_, "SPIR-V writer");
@@ -355,7 +355,7 @@
},
[&](const type::Array* arr) {
if (arr->ConstantCount()) {
- auto* count = ir_->constant_values.Get(u32(arr->ConstantCount().value()));
+ auto* count = b_.ConstantValue(u32(arr->ConstantCount().value()));
module_.PushType(spv::Op::OpTypeArray,
{id, Type(arr->ElemType()), Constant(count)});
} else {
@@ -1341,11 +1341,11 @@
case builtin::Function::kStorageBarrier:
op = spv::Op::OpControlBarrier;
operands.clear();
- operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
- operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
operands.push_back(
- Constant(ir_->constant_values.Get(u32(spv::MemorySemanticsMask::UniformMemory |
- spv::MemorySemanticsMask::AcquireRelease))));
+ Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::UniformMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
break;
case builtin::Function::kSubgroupBallot:
module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
@@ -1391,11 +1391,11 @@
case builtin::Function::kWorkgroupBarrier:
op = spv::Op::OpControlBarrier;
operands.clear();
- operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
- operands.push_back(Constant(ir_->constant_values.Get(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
operands.push_back(
- Constant(ir_->constant_values.Get(u32(spv::MemorySemanticsMask::WorkgroupMemory |
- spv::MemorySemanticsMask::AcquireRelease))));
+ Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::WorkgroupMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
break;
default:
TINT_ICE() << "unimplemented builtin function: " << builtin->Func();
@@ -1471,37 +1471,37 @@
operands.push_back(ConstantNull(arg_ty));
} else if (arg_ty->is_bool_scalar_or_vector()) {
// Select between constant one and zero, splatting them to vectors if necessary.
- const constant::Value* one = nullptr;
- const constant::Value* zero = nullptr;
+ ir::Constant* one = nullptr;
+ ir::Constant* zero = nullptr;
Switch(
res_ty->DeepestElement(), //
[&](const type::F32*) {
- one = ir_->constant_values.Get(1_f);
- zero = ir_->constant_values.Get(0_f);
+ one = b_.Constant(1_f);
+ zero = b_.Constant(0_f);
},
[&](const type::F16*) {
- one = ir_->constant_values.Get(1_h);
- zero = ir_->constant_values.Get(0_h);
+ one = b_.Constant(1_h);
+ zero = b_.Constant(0_h);
},
[&](const type::I32*) {
- one = ir_->constant_values.Get(1_i);
- zero = ir_->constant_values.Get(0_i);
+ one = b_.Constant(1_i);
+ zero = b_.Constant(0_i);
},
[&](const type::U32*) {
- one = ir_->constant_values.Get(1_u);
- zero = ir_->constant_values.Get(0_u);
+ one = b_.Constant(1_u);
+ zero = b_.Constant(0_u);
});
TINT_ASSERT_OR_RETURN(one && zero);
if (auto* vec = res_ty->As<type::Vector>()) {
// Splat the scalars into vectors.
- one = ir_->constant_values.Splat(vec, one, vec->Width());
- zero = ir_->constant_values.Splat(vec, zero, vec->Width());
+ one = b_.Splat(vec, one, vec->Width());
+ zero = b_.Splat(vec, zero, vec->Width());
}
op = spv::Op::OpSelect;
- operands.push_back(Constant(one));
- operands.push_back(Constant(zero));
+ operands.push_back(Constant(b_.ConstantValue(one)));
+ operands.push_back(Constant(b_.ConstantValue(zero)));
} else {
TINT_ICE() << "unhandled convert instruction";
}
diff --git a/src/tint/lang/spirv/writer/printer/printer.h b/src/tint/lang/spirv/writer/printer/printer.h
index 7c292a7..5f1c085 100644
--- a/src/tint/lang/spirv/writer/printer/printer.h
+++ b/src/tint/lang/spirv/writer/printer/printer.h
@@ -22,6 +22,7 @@
#include "src/tint/lang/core/builtin/builtin_value.h"
#include "src/tint/lang/core/builtin/texel_format.h"
#include "src/tint/lang/core/constant/value.h"
+#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/spirv/writer/common/binary_writer.h"
#include "src/tint/lang/spirv/writer/common/function.h"
@@ -266,6 +267,7 @@
void EmitExitPhis(ir::ControlInstruction* inst);
ir::Module* ir_;
+ ir::Builder b_;
writer::Module module_;
BinaryWriter writer_;