[spirv-reader] Handle StorageBuffer storage class
Use the presence of the `NonWritable` decoration to select between
`read` and `read_write` access modes. Propagate the access mode
through `access` instructions.
Check for unsupported extensions; allow the
SPV_KHR_storage_buffer_storage_class extension (and nothing else, for
now).
Bug: tint:1907
Change-Id: Ia288373ff079d49449a47a13820eb08e4ae40380
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/170480
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/memory_test.cc b/src/tint/lang/spirv/reader/parser/memory_test.cc
index 9f8c75f..acb3972 100644
--- a/src/tint/lang/spirv/reader/parser/memory_test.cc
+++ b/src/tint/lang/spirv/reader/parser/memory_test.cc
@@ -760,5 +760,59 @@
)");
}
+TEST_F(SpirvParserTest, StorageBufferAccessMode) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %str Block
+ OpMemberDecorate %str 0 Offset 0
+ OpDecorate %ro_var NonWritable
+ OpDecorate %ro_var DescriptorSet 1
+ OpDecorate %ro_var Binding 2
+ OpDecorate %rw_var DescriptorSet 1
+ OpDecorate %rw_var Binding 3
+ %void = OpTypeVoid
+ %u32 = OpTypeInt 32 0
+ %str = OpTypeStruct %u32
+ %u32_ptr = OpTypePointer StorageBuffer %u32
+ %str_ptr = OpTypePointer StorageBuffer %str
+ %ep_type = OpTypeFunction %void
+ %u32_0 = OpConstant %u32 0
+ %ro_var = OpVariable %str_ptr StorageBuffer
+ %rw_var = OpVariable %str_ptr StorageBuffer
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ %ro_access = OpAccessChain %u32_ptr %ro_var %u32_0
+ %rw_access = OpAccessChain %u32_ptr %rw_var %u32_0
+ %load = OpLoad %u32 %ro_access
+ OpStore %rw_access %load
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+tint_symbol_1 = struct @align(4) {
+ tint_symbol:u32 @offset(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, tint_symbol_1, read> = var @binding_point(1, 2)
+ %2:ptr<storage, tint_symbol_1, read_write> = var @binding_point(1, 3)
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %4:ptr<storage, u32, read> = access %1, 0u
+ %5:ptr<storage, u32, read_write> = access %2, 0u
+ %6:u32 = load %4
+ store %5, %6
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 6fe4eb9..431ae28 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -76,6 +76,14 @@
return Failure("failed to build the internal representation of the module");
}
+ // Check for unsupported extensions.
+ for (const auto& ext : spirv_context_->extensions()) {
+ auto name = ext.GetOperand(0).AsString();
+ if (name != "SPV_KHR_storage_buffer_storage_class") {
+ return Failure("SPIR-V extension '" + name + "' is not supported");
+ }
+ }
+
{
TINT_SCOPED_ASSIGNMENT(current_block_, ir_.root_block);
EmitModuleScopeVariables();
@@ -99,6 +107,8 @@
return core::AddressSpace::kFunction;
case spv::StorageClass::Private:
return core::AddressSpace::kPrivate;
+ case spv::StorageClass::StorageBuffer:
+ return core::AddressSpace::kStorage;
case spv::StorageClass::Uniform:
return core::AddressSpace::kUniform;
default:
@@ -109,9 +119,11 @@
}
/// @param type a SPIR-V type object
+ /// @param access_mode an optional access mode (for pointers)
/// @returns a Tint type object
- const core::type::Type* Type(const spvtools::opt::analysis::Type* type) {
- return types_.GetOrCreate(type, [&]() -> const core::type::Type* {
+ const core::type::Type* Type(const spvtools::opt::analysis::Type* type,
+ core::Access access_mode = core::Access::kUndefined) {
+ return types_.GetOrCreate(TypeKey{type, access_mode}, [&]() -> const core::type::Type* {
switch (type->kind()) {
case spvtools::opt::analysis::Type::kVoid:
return ty_.void_();
@@ -156,7 +168,7 @@
case spvtools::opt::analysis::Type::kPointer: {
auto* ptr_ty = type->AsPointer();
return ty_.ptr(AddressSpace(ptr_ty->storage_class()),
- Type(ptr_ty->pointee_type()));
+ Type(ptr_ty->pointee_type()), access_mode);
}
default:
TINT_UNIMPLEMENTED() << "unhandled SPIR-V type: " << type->str();
@@ -166,9 +178,10 @@
}
/// @param id a SPIR-V result ID for a type declaration instruction
+ /// @param access_mode an optional access mode (for pointers)
/// @returns a Tint type object
- const core::type::Type* Type(uint32_t id) {
- return Type(spirv_context_->get_type_mgr()->GetType(id));
+ const core::type::Type* Type(uint32_t id, core::Access access_mode = core::Access::kUndefined) {
+ return Type(spirv_context_->get_type_mgr()->GetType(id), access_mode);
}
/// @param arr_ty a SPIR-V array object
@@ -459,7 +472,14 @@
indices.Push(Value(inst.GetSingleWordOperand(i)));
}
auto* base = Value(inst.GetSingleWordOperand(2));
- auto* access = b_.Access(Type(inst.type_id()), base, std::move(indices));
+
+ // Propagate the access mode of the base object.
+ auto access_mode = core::Access::kUndefined;
+ if (auto* ptr = base->Type()->As<core::type::Pointer>()) {
+ access_mode = ptr->Access();
+ }
+
+ auto* access = b_.Access(Type(inst.type_id(), access_mode), base, std::move(indices));
Emit(access, inst.result_id());
}
@@ -496,19 +516,16 @@
/// @param inst the SPIR-V instruction for OpVariable
void EmitVar(const spvtools::opt::Instruction& inst) {
- auto* var = b_.Var(Type(inst.type_id())->As<core::type::Pointer>());
- if (inst.NumOperands() > 3) {
- var->SetInitializer(Value(inst.GetSingleWordOperand(3)));
- }
-
// Handle decorations.
std::optional<uint32_t> group;
std::optional<uint32_t> binding;
+ core::Access access_mode = core::Access::kUndefined;
for (auto* deco :
spirv_context_->get_decoration_mgr()->GetDecorationsFor(inst.result_id(), false)) {
auto d = deco->GetSingleWordOperand(1);
switch (spv::Decoration(d)) {
case spv::Decoration::NonWritable:
+ access_mode = core::Access::kRead;
break;
case spv::Decoration::DescriptorSet:
group = deco->GetSingleWordOperand(2);
@@ -521,6 +538,12 @@
break;
}
}
+
+ auto* var = b_.Var(Type(inst.type_id(), access_mode)->As<core::type::Pointer>());
+ if (inst.NumOperands() > 3) {
+ var->SetInitializer(Value(inst.GetSingleWordOperand(3)));
+ }
+
if (group || binding) {
TINT_ASSERT(group && binding);
var->SetBindingPoint(group.value(), binding.value());
@@ -530,6 +553,28 @@
}
private:
+ /// TypeKey describes a SPIR-V type with an access mode.
+ struct TypeKey {
+ /// The SPIR-V type object.
+ const spvtools::opt::analysis::Type* type;
+ /// The access mode.
+ core::Access access_mode;
+
+ // Equality operator for TypeKey.
+ bool operator==(const TypeKey& other) const {
+ return type == other.type && access_mode == other.access_mode;
+ }
+
+ /// Hasher provides a hash function for the TypeKey.
+ struct Hasher {
+ /// @param tk the TypeKey to create a hash for
+ /// @return the hash value
+ inline std::size_t operator()(const TypeKey& tk) const {
+ return HashCombine(Hash(tk.type), tk.access_mode);
+ }
+ };
+ };
+
/// The generated IR module.
core::ir::Module ir_;
/// The Tint IR builder.
@@ -541,8 +586,8 @@
core::ir::Function* current_function_ = nullptr;
/// The Tint IR block that is currently being emitted.
core::ir::Block* current_block_ = nullptr;
- /// A map from a SPIR-V type declaration result ID to the corresponding Tint type object.
- Hashmap<const spvtools::opt::analysis::Type*, const core::type::Type*, 16> types_;
+ /// A map from a SPIR-V type declaration to the corresponding Tint type object.
+ Hashmap<TypeKey, const core::type::Type*, 16, TypeKey::Hasher> types_;
/// A map from a SPIR-V function definition result ID to the corresponding Tint function object.
Hashmap<uint32_t, core::ir::Function*, 8> functions_;
/// A map from a SPIR-V result ID to the corresponding Tint value object.
diff --git a/src/tint/lang/spirv/reader/parser/var_test.cc b/src/tint/lang/spirv/reader/parser/var_test.cc
index 88a6347..dfded72 100644
--- a/src/tint/lang/spirv/reader/parser/var_test.cc
+++ b/src/tint/lang/spirv/reader/parser/var_test.cc
@@ -142,6 +142,129 @@
)");
}
+TEST_F(SpirvParserTest, StorageVar_ReadOnly) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ OpDecorate %str Block
+ OpMemberDecorate %str 0 Offset 0
+ OpDecorate %6 NonWritable
+ OpDecorate %6 DescriptorSet 1
+ OpDecorate %6 Binding 2
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %str = OpTypeStruct %uint
+%_ptr_StorageBuffer_str = OpTypePointer StorageBuffer %str
+ %5 = OpTypeFunction %void
+ %6 = OpVariable %_ptr_StorageBuffer_str StorageBuffer
+ %1 = OpFunction %void None %5
+ %7 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+tint_symbol_1 = struct @align(4) {
+ tint_symbol:u32 @offset(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, tint_symbol_1, read> = var @binding_point(1, 2)
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, StorageVar_ReadWrite) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ OpDecorate %str Block
+ OpMemberDecorate %str 0 Offset 0
+ OpDecorate %6 DescriptorSet 1
+ OpDecorate %6 Binding 2
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %str = OpTypeStruct %uint
+%_ptr_StorageBuffer_str = OpTypePointer StorageBuffer %str
+ %5 = OpTypeFunction %void
+ %6 = OpVariable %_ptr_StorageBuffer_str StorageBuffer
+ %1 = OpFunction %void None %5
+ %7 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+tint_symbol_1 = struct @align(4) {
+ tint_symbol:u32 @offset(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, tint_symbol_1, read_write> = var @binding_point(1, 2)
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, StorageVar_ReadOnly_And_ReadWrite) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ OpExecutionMode %1 LocalSize 1 1 1
+ OpDecorate %str Block
+ OpMemberDecorate %str 0 Offset 0
+ OpDecorate %6 NonWritable
+ OpDecorate %6 DescriptorSet 1
+ OpDecorate %6 Binding 2
+ OpDecorate %7 DescriptorSet 1
+ OpDecorate %7 Binding 3
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %str = OpTypeStruct %uint
+%_ptr_StorageBuffer_str = OpTypePointer StorageBuffer %str
+ %5 = OpTypeFunction %void
+ %6 = OpVariable %_ptr_StorageBuffer_str StorageBuffer
+ %7 = OpVariable %_ptr_StorageBuffer_str StorageBuffer
+ %1 = OpFunction %void None %5
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+tint_symbol_1 = struct @align(4) {
+ tint_symbol:u32 @offset(0)
+}
+
+%b1 = block { # root
+ %1:ptr<storage, tint_symbol_1, read> = var @binding_point(1, 2)
+ %2:ptr<storage, tint_symbol_1, read_write> = var @binding_point(1, 3)
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ ret
+ }
+}
+)");
+}
+
TEST_F(SpirvParserTest, UniformVar) {
EXPECT_IR(R"(
OpCapability Shader
diff --git a/src/tint/lang/spirv/reader/reader_test.cc b/src/tint/lang/spirv/reader/reader_test.cc
index 1a086a8..5313aaf 100644
--- a/src/tint/lang/spirv/reader/reader_test.cc
+++ b/src/tint/lang/spirv/reader/reader_test.cc
@@ -68,6 +68,25 @@
}
};
+TEST_F(SpirvReaderTest, UnsupportedExtension) {
+ auto got = Run(R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %ep_type = OpTypeFunction %void
+ %main = OpFunction %void None %ep_type
+ %main_start = OpLabel
+ OpReturn
+ OpFunctionEnd
+)");
+ ASSERT_NE(got, Success);
+ EXPECT_EQ(got.Failure().reason.str(),
+ "error: SPIR-V extension 'SPV_KHR_variable_pointers' is not supported");
+}
+
TEST_F(SpirvReaderTest, Load_VectorComponent) {
auto got = Run(R"(
OpCapability Shader