[spirv-writer] Add support for derivatives.
This CL adds support for generating the various dpdx, dpdy and fwidth
instructions.
Bug: tint:5
Change-Id: I6d12c738b93931d1e740659d9c1871892b801f71
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22625
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/intrinsic.cc b/src/ast/intrinsic.cc
index 30311f5..73cefac 100644
--- a/src/ast/intrinsic.cc
+++ b/src/ast/intrinsic.cc
@@ -18,10 +18,18 @@
namespace ast {
namespace intrinsic {
+bool IsCoarseDerivative(const std::string& name) {
+ return name == "dpdx_coarse" || name == "dpdy_coarse" ||
+ name == "fwidth_coarse";
+}
+
+bool IsFineDerivative(const std::string& name) {
+ return name == "dpdx_fine" || name == "dpdy_fine" || name == "fwidth_fine";
+}
+
bool IsDerivative(const std::string& name) {
- return name == "dpdx" || name == "dpdx_fine" || name == "dpdx_coarse" ||
- name == "dpdy" || name == "dpdy_fine" || name == "dpdy_coarse" ||
- name == "fwidth" || name == "fwidth_fine" || name == "fwidth_coarse";
+ return name == "dpdx" || name == "dpdy" || name == "fwidth" ||
+ IsCoarseDerivative(name) || IsFineDerivative(name);
}
bool IsFloatClassificationIntrinsic(const std::string& name) {
diff --git a/src/ast/intrinsic.h b/src/ast/intrinsic.h
index f827b29..b11a73c 100644
--- a/src/ast/intrinsic.h
+++ b/src/ast/intrinsic.h
@@ -21,7 +21,17 @@
namespace ast {
namespace intrinsic {
-/// Determine if the given |name | is a derivative intrinsic
+/// Determines if the given |name| is a coarse derivative
+/// @param name the name to check
+/// @returns true if the given derivative is coarse.
+bool IsCoarseDerivative(const std::string& name);
+
+/// Determines if the given |name| is a fine derivative
+/// @param name the name to check
+/// @returns true if the given derivative is fine.
+bool IsFineDerivative(const std::string& name);
+
+/// Determine if the given |name| is a derivative intrinsic
/// @param name the name to check
/// @returns true if the given |name| is a derivative intrinsic
bool IsDerivative(const std::string& name);
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index aeb4b94..f2b220c 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -138,11 +138,11 @@
Builder::~Builder() = default;
bool Builder::Build() {
- push_preamble(spv::Op::OpCapability, {Operand::Int(SpvCapabilityShader)});
+ push_capability(spv::Op::OpCapability, {Operand::Int(SpvCapabilityShader)});
// TODO(dneto): Stop using the Vulkan memory model. crbug.com/tint/63
- push_preamble(spv::Op::OpCapability,
- {Operand::Int(SpvCapabilityVulkanMemoryModel)});
+ push_capability(spv::Op::OpCapability,
+ {Operand::Int(SpvCapabilityVulkanMemoryModel)});
push_preamble(spv::Op::OpExtension,
{Operand::String("SPV_KHR_vulkan_memory_model")});
@@ -188,6 +188,7 @@
// The 5 covers the magic, version, generator, id bound and reserved.
uint32_t size = 5;
+ size += size_of(capabilities_);
size += size_of(preamble_);
size += size_of(debug_);
size += size_of(annotations_);
@@ -200,6 +201,9 @@
}
void Builder::iterate(std::function<void(const Instruction&)> cb) const {
+ for (const auto& inst : capabilities_) {
+ cb(inst);
+ }
for (const auto& inst : preamble_) {
cb(inst);
}
@@ -1245,6 +1249,12 @@
params.push_back(Operand::Int(val_id));
}
+ if (ast::intrinsic::IsFineDerivative(name) ||
+ ast::intrinsic::IsCoarseDerivative(name)) {
+ push_capability(spv::Op::OpCapability,
+ {Operand::Int(SpvCapabilityDerivativeControl)});
+ }
+
spv::Op op = spv::Op::OpNop;
if (name == "any") {
op = spv::Op::OpAny;
@@ -1252,6 +1262,24 @@
op = spv::Op::OpAll;
} else if (name == "dot") {
op = spv::Op::OpDot;
+ } else if (name == "dpdx") {
+ op = spv::Op::OpDPdx;
+ } else if (name == "dpdx_coarse") {
+ op = spv::Op::OpDPdxCoarse;
+ } else if (name == "dpdx_fine") {
+ op = spv::Op::OpDPdxFine;
+ } else if (name == "dpdy") {
+ op = spv::Op::OpDPdy;
+ } else if (name == "dpdy_coarse") {
+ op = spv::Op::OpDPdyCoarse;
+ } else if (name == "dpdy_fine") {
+ op = spv::Op::OpDPdyFine;
+ } else if (name == "fwidth") {
+ op = spv::Op::OpFwidth;
+ } else if (name == "fwidth_coarse") {
+ op = spv::Op::OpFwidthCoarse;
+ } else if (name == "fwidth_fine") {
+ op = spv::Op::OpFwidthFine;
} else if (name == "is_inf") {
op = spv::Op::OpIsInf;
} else if (name == "is_nan") {
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 5a43021..b8ac524 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -119,6 +119,14 @@
/// @param cb the callback to execute
void iterate(std::function<void(const Instruction&)> cb) const;
+ /// Adds an instruction to the list of capabilities
+ /// @param op the op to set
+ /// @param operands the operands for the instruction
+ void push_capability(spv::Op op, const std::vector<Operand>& operands) {
+ capabilities_.push_back(Instruction{op, operands});
+ }
+ /// @returns the capabilities
+ const std::vector<Instruction>& capabilities() const { return capabilities_; }
/// Adds an instruction to the preamble
/// @param op the op to set
/// @param operands the operands for the instruction
@@ -386,6 +394,7 @@
ast::Module* mod_;
std::string error_;
uint32_t next_id_ = 1;
+ std::vector<Instruction> capabilities_;
std::vector<Instruction> preamble_;
std::vector<Instruction> debug_;
std::vector<Instruction> types_;
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index 468ead6..34d434e 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -208,6 +208,102 @@
)");
}
+using IntrinsicDeriveTest = testing::TestWithParam<IntrinsicData>;
+TEST_P(IntrinsicDeriveTest, Call_Derivative_Scalar) {
+ auto param = GetParam();
+
+ ast::type::F32Type f32;
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &f32);
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
+ ast::CallExpression expr(
+ std::make_unique<ast::IdentifierExpression>(param.name),
+ std::move(params));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateCallExpression(&expr), 5u) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%6 = OpLoad %3 %1
+%5 = )" + param.op +
+ " %3 %6\n");
+}
+
+TEST_P(IntrinsicDeriveTest, Call_Derivative_Vector) {
+ auto param = GetParam();
+
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &vec3);
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
+ ast::CallExpression expr(
+ std::make_unique<ast::IdentifierExpression>(param.name),
+ std::move(params));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateCallExpression(&expr), 6u) << b.error();
+
+ if (param.name != "dpdx" && param.name != "dpdy" && param.name != "fwidth") {
+ EXPECT_EQ(DumpInstructions(b.capabilities()),
+ R"(OpCapability DerivativeControl
+)");
+ }
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpLoad %3 %1
+%6 = )" + param.op +
+ " %3 %7\n");
+}
+INSTANTIATE_TEST_SUITE_P(
+ BuilderTest,
+ IntrinsicDeriveTest,
+ testing::Values(IntrinsicData{"dpdx", "OpDPdx"},
+ IntrinsicData{"dpdx_fine", "OpDPdxFine"},
+ IntrinsicData{"dpdx_coarse", "OpDPdxCoarse"},
+ IntrinsicData{"dpdy", "OpDPdy"},
+ IntrinsicData{"dpdy_fine", "OpDPdyFine"},
+ IntrinsicData{"dpdy_coarse", "OpDPdyCoarse"},
+ IntrinsicData{"fwidth", "OpFwidth"},
+ IntrinsicData{"fwidth_fine", "OpFwidthFine"},
+ IntrinsicData{"fwidth_coarse", "OpFwidthCoarse"}));
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/src/writer/spirv/builder_test.cc b/src/writer/spirv/builder_test.cc
index f6f88a5..6c80d8f 100644
--- a/src/writer/spirv/builder_test.cc
+++ b/src/writer/spirv/builder_test.cc
@@ -36,7 +36,8 @@
Builder b(&m);
ASSERT_TRUE(b.Build());
- ASSERT_EQ(b.preamble().size(), 5u);
+ ASSERT_EQ(b.capabilities().size(), 2u);
+ ASSERT_EQ(b.preamble().size(), 3u);
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpCapability VulkanMemoryModel