[spirv-writer] Add support for sample index/mask builtins
Use a sanitizing transform to convert scalar `sample_mask_{in,out}`
variables to single element arrays.
Add the `SampleRateShading` capability if the `sample_index` builtin
is used.
Bug: tint:372
Change-Id: Id7280e3ddb21e0a098d83587d123c97e3c34fa1b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41662
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 03c2274..0d68e41 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -355,8 +355,6 @@
"src/clone_context.h",
"src/demangler.cc",
"src/demangler.h",
- "src/intrinsic_table.cc",
- "src/intrinsic_table.h",
"src/diagnostic/diagnostic.cc",
"src/diagnostic/diagnostic.h",
"src/diagnostic/formatter.cc",
@@ -369,6 +367,8 @@
"src/inspector/inspector.h",
"src/inspector/scalar.cc",
"src/inspector/scalar.h",
+ "src/intrinsic_table.cc",
+ "src/intrinsic_table.h",
"src/namer.cc",
"src/namer.h",
"src/program.cc",
@@ -953,6 +953,7 @@
source_set("tint_unittests_spv_writer_src") {
sources = [
+ "src/transform/spirv_test.cc",
"src/writer/spirv/binary_writer_test.cc",
"src/writer/spirv/builder_accessor_expression_test.cc",
"src/writer/spirv/builder_assign_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index a42d8e3..d1648b9 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -616,6 +616,7 @@
if(${TINT_BUILD_SPV_WRITER})
list(APPEND TINT_TEST_SRCS
+ transform/spirv_test.cc
writer/spirv/binary_writer_test.cc
writer/spirv/builder_accessor_expression_test.cc
writer/spirv/builder_assign_test.cc
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 07797d9..0f3d7fb 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -14,9 +14,13 @@
#include "src/transform/spirv.h"
+#include <string>
#include <utility>
+#include "src/ast/variable.h"
#include "src/program_builder.h"
+#include "src/semantic/expression.h"
+#include "src/semantic/variable.h"
namespace tint {
namespace transform {
@@ -26,9 +30,60 @@
Transform::Output Spirv::Run(const Program* in) {
ProgramBuilder out;
- CloneContext(&out, in).Clone();
+ CloneContext ctx(&out, in);
+ HandleSampleMaskBuiltins(ctx);
+ ctx.Clone();
return Output{Program(std::move(out))};
}
+void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
+ // Find global variables decorated with [[builtin(sample_mask_{in,out})]] and
+ // change their type from `u32` to `array<u32, 1>`, as required by Vulkan.
+ //
+ // Before:
+ // ```
+ // [[builtin(sample_mask_out)]] var<out> mask_out : u32;
+ // fn main() -> void {
+ // mask_out = 1u;
+ // }
+ // ```
+ // After:
+ // ```
+ // [[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+ // fn main() -> void {
+ // mask_out[0] = 1u;
+ // }
+ // ```
+
+ for (auto* var : ctx.src->AST().GlobalVariables()) {
+ for (auto* deco : var->decorations()) {
+ if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
+ if (builtin->value() != ast::Builtin::kSampleMaskIn &&
+ builtin->value() != ast::Builtin::kSampleMaskOut) {
+ continue;
+ }
+
+ // Use the same name as the old variable.
+ std::string var_name = ctx.src->Symbols().NameFor(var->symbol());
+ // Use `array<u32, 1>` for the new variable.
+ auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
+ // Create the new variable.
+ auto* var_arr =
+ ctx.dst->Var(var->source(), var_name, var->declared_storage_class(),
+ type, nullptr, ctx.Clone(var->decorations()));
+ // Replace the variable with the arrayed version.
+ ctx.Replace(var, var_arr);
+
+ // Replace all uses of the old variable with `var_arr[0]`.
+ for (auto* user : ctx.src->Sem().Get(var)->Users()) {
+ auto* new_ident = ctx.dst->IndexAccessor(
+ ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0));
+ ctx.Replace<ast::Expression>(user->Declaration(), new_ident);
+ }
+ }
+ }
+ }
+}
+
} // namespace transform
} // namespace tint
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index 3644445..5f79d81 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -18,6 +18,10 @@
#include "src/transform/transform.h"
namespace tint {
+
+// Forward declarations
+class CloneContext;
+
namespace transform {
/// Spirv is a transform used to sanitize a Program for use with the Spirv
@@ -36,6 +40,10 @@
/// @param program the source program to transform
/// @returns the transformation result
Output Run(const Program* program) override;
+
+ private:
+ /// Change type of sample mask builtin variables to single element arrays.
+ void HandleSampleMaskBuiltins(CloneContext& ctx) const;
};
} // namespace transform
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
new file mode 100644
index 0000000..53bdc52
--- /dev/null
+++ b/src/transform/spirv_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/spirv.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using SpirvTest = TransformTest;
+
+TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
+ auto* src = R"(
+[[builtin(sample_index)]] var<in> sample_index : u32;
+
+[[builtin(sample_mask_in)]] var<in> mask_in : u32;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : u32;
+
+[[stage(fragment)]]
+fn main() -> void {
+ mask_out = mask_in;
+}
+)";
+
+ auto* expect = R"(
+[[builtin(sample_index)]] var<in> sample_index : u32;
+
+[[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+
+[[stage(fragment)]]
+fn main() -> void {
+ mask_out[0] = mask_in[0];
+}
+)";
+
+ auto got = Transform<Spirv>(src);
+
+ EXPECT_EQ(expect, got);
+}
+
+TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
+ auto* src = R"(
+[[builtin(sample_mask_in)]] var<in> mask_in : u32;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : u32;
+
+fn filter(mask: u32) -> u32 {
+ return (mask & 3u);
+}
+
+fn set_mask(input : u32) -> void {
+ mask_out = input;
+}
+
+[[stage(fragment)]]
+fn main() -> void {
+ set_mask(filter(mask_in));
+}
+)";
+
+ auto* expect = R"(
+[[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+
+fn filter(mask : u32) -> u32 {
+ return (mask & 3u);
+}
+
+fn set_mask(input : u32) -> void {
+ mask_out[0] = input;
+}
+
+[[stage(fragment)]]
+fn main() -> void {
+ set_mask(filter(mask_in[0]));
+}
+)";
+
+ auto got = Transform<Spirv>(src);
+
+ EXPECT_EQ(expect, got);
+}
+
+} // namespace
+} // namespace transform
+} // namespace tint
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 6b545e3..bd75938 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -3172,7 +3172,7 @@
return SpvStorageClassMax;
}
-SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin) const {
+SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin) {
switch (builtin) {
case ast::Builtin::kPosition:
return SpvBuiltInPosition;
@@ -3194,9 +3194,13 @@
return SpvBuiltInGlobalInvocationId;
case ast::Builtin::kPointSize:
return SpvBuiltInPointSize;
- case ast::Builtin::kSampleIndex: // TODO(dneto)
- case ast::Builtin::kSampleMaskIn: // TODO(dneto)
- case ast::Builtin::kSampleMaskOut: // TODO(dneto)
+ case ast::Builtin::kSampleIndex:
+ push_capability(SpvCapabilitySampleRateShading);
+ return SpvBuiltInSampleId;
+ case ast::Builtin::kSampleMaskIn:
+ return SpvBuiltInSampleMask;
+ case ast::Builtin::kSampleMaskOut:
+ return SpvBuiltInSampleMask;
case ast::Builtin::kNone:
break;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index b1d38b3..bba8bd5 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -216,10 +216,10 @@
/// @param klass the storage class to convert
/// @returns the SPIR-V storage class or SpvStorageClassMax on error.
SpvStorageClass ConvertStorageClass(ast::StorageClass klass) const;
- /// Converts a builtin to a SPIR-V builtin
+ /// Converts a builtin to a SPIR-V builtin and pushes a capability if needed.
/// @param builtin the builtin to convert
/// @returns the SPIR-V builtin or SpvBuiltInMax on error.
- SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const;
+ SpvBuiltIn ConvertBuiltin(ast::Builtin builtin);
/// Generates a label for the given id. Emits an error and returns false if
/// we're currently outside a function.
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 4a9517b..c39c281 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -24,6 +24,7 @@
#include "src/ast/group_decoration.h"
#include "src/ast/location_decoration.h"
#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/stage_decoration.h"
#include "src/ast/storage_class.h"
#include "src/ast/struct.h"
#include "src/ast/type_constructor_expression.h"
@@ -401,7 +402,10 @@
BuiltinData{ast::Builtin::kLocalInvocationIndex,
SpvBuiltInLocalInvocationIndex},
BuiltinData{ast::Builtin::kGlobalInvocationId,
- SpvBuiltInGlobalInvocationId}));
+ SpvBuiltInGlobalInvocationId},
+ BuiltinData{ast::Builtin::kSampleIndex, SpvBuiltInSampleId},
+ BuiltinData{ast::Builtin::kSampleMaskIn, SpvBuiltInSampleMask},
+ BuiltinData{ast::Builtin::kSampleMaskOut, SpvBuiltInSampleMask}));
TEST_F(BuilderTest, GlobalVar_DeclReadOnly) {
// struct A {
@@ -629,6 +633,95 @@
)");
}
+TEST_F(BuilderTest, SampleIndex) {
+ auto* var =
+ Global("sample_index", ast::StorageClass::kInput, ty.u32(), nullptr,
+ ast::VariableDecorationList{
+ create<ast::BuiltinDecoration>(ast::Builtin::kSampleIndex),
+ });
+
+ spirv::Builder& b = Build();
+
+ EXPECT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.capabilities()),
+ "OpCapability SampleRateShading\n");
+ EXPECT_EQ(DumpInstructions(b.annots()), "OpDecorate %1 BuiltIn SampleId\n");
+ EXPECT_EQ(DumpInstructions(b.types()),
+ "%3 = OpTypeInt 32 0\n"
+ "%2 = OpTypePointer Input %3\n"
+ "%1 = OpVariable %2 Input\n");
+}
+
+TEST_F(BuilderTest, SampleMask) {
+ // Input:
+ // [[builtin(sample_mask_in)]] var<in> mask_in : u32;
+ // [[builtin(sample_mask_out)]] var<out> mask_out : u32;
+ // [[stage(fragment)]]
+ // fn main() -> void {
+ // mask_out = mask_in;
+ // }
+
+ // After sanitization:
+ // [[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+ // [[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+ // [[stage(fragment)]]
+ // fn main() -> void {
+ // mask_out[0] = mask_in[0];
+ // }
+
+ Global("mask_in", ast::StorageClass::kInput, ty.u32(), nullptr,
+ ast::VariableDecorationList{
+ create<ast::BuiltinDecoration>(ast::Builtin::kSampleMaskIn),
+ });
+ Global("mask_out", ast::StorageClass::kOutput, ty.u32(), nullptr,
+ ast::VariableDecorationList{
+ create<ast::BuiltinDecoration>(ast::Builtin::kSampleMaskOut),
+ });
+ Func("main", ast::VariableList{}, ty.void_(),
+ ast::StatementList{
+ create<ast::AssignmentStatement>(Expr("mask_out"), Expr("mask_in")),
+ },
+ ast::FunctionDecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ });
+
+ spirv::Builder& b = SanitizeAndBuild();
+
+ ASSERT_TRUE(b.Build());
+ EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %11 "main" %6 %1
+OpExecutionMode %11 LocalSize 1 1 1
+OpName %1 "mask_in"
+OpName %6 "mask_out"
+OpName %11 "main"
+OpDecorate %1 BuiltIn SampleMask
+OpDecorate %6 BuiltIn SampleMask
+%4 = OpTypeInt 32 0
+%5 = OpConstant %4 1
+%3 = OpTypeArray %4 %5
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%7 = OpTypePointer Output %3
+%8 = OpConstantNull %3
+%6 = OpVariable %7 Output %8
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
+%13 = OpTypeInt 32 1
+%14 = OpConstant %13 0
+%15 = OpTypePointer Output %4
+%17 = OpTypePointer Input %4
+%11 = OpFunction %10 None %9
+%12 = OpLabel
+%16 = OpAccessChain %15 %6 %14
+%18 = OpAccessChain %17 %1 %14
+%19 = OpLoad %4 %18
+OpStore %16 %19
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/src/writer/spirv/test_helper.h b/src/writer/spirv/test_helper.h
index aafc15d..ba5e51a 100644
--- a/src/writer/spirv/test_helper.h
+++ b/src/writer/spirv/test_helper.h
@@ -24,6 +24,7 @@
#include "src/ast/module.h"
#include "src/diagnostic/formatter.h"
#include "src/program_builder.h"
+#include "src/transform/spirv.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/binary_writer.h"
#include "src/writer/spirv/builder.h"
@@ -60,6 +61,34 @@
return *spirv_builder;
}
+ /// Builds the program, runs the program through the transform::Spirv
+ /// sanitizer and returns a spirv::Builder from the sanitized program.
+ /// @note The spirv::Builder is only built once. Multiple calls to Build()
+ /// will return the same spirv::Builder without rebuilding.
+ /// @return the built spirv::Builder
+ spirv::Builder& SanitizeAndBuild() {
+ if (spirv_builder) {
+ return *spirv_builder;
+ }
+ [&]() {
+ ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
+ << diag::Formatter().format(Diagnostics());
+ }();
+ program = std::make_unique<Program>(std::move(*this));
+ [&]() {
+ ASSERT_TRUE(program->IsValid())
+ << diag::Formatter().format(program->Diagnostics());
+ }();
+ auto result = transform::Spirv().Run(program.get());
+ [&]() {
+ ASSERT_FALSE(result.diagnostics.contains_errors())
+ << diag::Formatter().format(result.diagnostics);
+ }();
+ *program = std::move(result.program);
+ spirv_builder = std::make_unique<spirv::Builder>(program.get());
+ return *spirv_builder;
+ }
+
/// Validate passes the generated SPIR-V of the builder `b` to the SPIR-V
/// Tools Validator. If the validator finds problems the test will fail.
/// @param b the spirv::Builder containing the built SPIR-V module