spirv-reader: Track storage class for pointer/ref values

Fixed: tint:1041 tint:1648
Change-Id: I28c6677e0ef3f96902f4f9ced030c2280a17c247
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104762
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 4b7b3d3..e133e58 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -2557,6 +2557,8 @@
     }
     auto type_it = identifier_types_.find(id);
     if (type_it != identifier_types_.end()) {
+        // We have a local named definition: function parameter, let, or var
+        // declaration.
         auto name = namer_.Name(id);
         auto* type = type_it->second;
         return TypedExpression{
@@ -2585,10 +2587,13 @@
     switch (inst->opcode()) {
         case SpvOpVariable: {
             // This occurs for module-scope variables.
-            auto name = namer_.Name(inst->result_id());
-            return TypedExpression{
-                parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref),
-                create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name))};
+            auto name = namer_.Name(id);
+            // Construct the reference type, mapping storage class correctly.
+            const auto* type =
+                RemapPointerProperties(parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref), id);
+            // TODO(crbug.com/tint/1041): Fix access mode
+            return TypedExpression{type, create<ast::IdentifierExpression>(
+                                             Source{}, builder_.Symbols().Register(name))};
         }
         case SpvOpUndef:
             // Substitute a null value for undef.
@@ -3356,11 +3361,13 @@
     for (auto id : sorted_by_index(block_info.hoisted_ids)) {
         const auto* def_inst = def_use_mgr_->GetDef(id);
         TINT_ASSERT(Reader, def_inst);
-        auto* storage_type = RemapAddressSpace(parser_impl_.ConvertType(def_inst->type_id()), id);
+        // Compute the store type.  Pointers are not storable, so there is
+        // no need to remap pointer properties.
+        auto* store_type = parser_impl_.ConvertType(def_inst->type_id());
         AddStatement(create<ast::VariableDeclStatement>(
-            Source{}, parser_impl_.MakeVar(id, ast::AddressSpace::kNone, storage_type, nullptr,
+            Source{}, parser_impl_.MakeVar(id, ast::AddressSpace::kNone, store_type, nullptr,
                                            AttributeList{})));
-        auto* type = ty_.Reference(storage_type, ast::AddressSpace::kNone);
+        auto* type = ty_.Reference(store_type, ast::AddressSpace::kNone);
         identifier_types_.emplace(id, type);
     }
 
@@ -3449,6 +3456,7 @@
     }
 
     expr = AddressOfIfNeeded(expr, &inst);
+    expr.type = RemapPointerProperties(expr.type, inst.result_id());
     auto* let = parser_impl_.MakeLet(inst.result_id(), expr.type, expr.expr);
     if (!let) {
         return false;
@@ -3720,7 +3728,6 @@
             if (!expr) {
                 return false;
             }
-            expr.type = RemapAddressSpace(expr.type, result_id);
             return EmitConstDefOrWriteToHoistedVar(inst, expr);
         }
 
@@ -3777,20 +3784,6 @@
     return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
 }
 
-TypedExpression FunctionEmitter::InferFunctionAddressSpace(TypedExpression expr) {
-    TypedExpression result(expr);
-    if (const auto* ref = expr.type->UnwrapAlias()->As<Reference>()) {
-        if (ref->address_space == ast::AddressSpace::kNone) {
-            expr.type = ty_.Reference(ref->type, ast::AddressSpace::kFunction);
-        }
-    } else if (const auto* ptr = expr.type->UnwrapAlias()->As<Pointer>()) {
-        if (ptr->address_space == ast::AddressSpace::kNone) {
-            expr.type = ty_.Pointer(ptr->type, ast::AddressSpace::kFunction);
-        }
-    }
-    return expr;
-}
-
 TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
     const spvtools::opt::Instruction& inst) {
     if (inst.result_id() == 0) {
@@ -4350,6 +4343,10 @@
     const auto num_in_operands = inst.NumInOperands();
 
     bool sink_pointer = false;
+    // The current WGSL expression for the pointer, starting with the base
+    // pointer and updated as each index is incorporated.  The important part
+    // is the pointee (or "store type").  The address space and access mode will
+    // be patched as needed at the very end, via RemapPointerProperties.
     TypedExpression current_expr;
 
     // If the variable was originally gl_PerVertex, then in the AST we
@@ -4418,7 +4415,7 @@
     // ever-deeper nested indexing expressions. Start off with an expression
     // for the base, and then bury that inside nested indexing expressions.
     if (!current_expr) {
-        current_expr = InferFunctionAddressSpace(MakeOperand(inst, 0));
+        current_expr = MakeOperand(inst, 0);
         if (current_expr.type->Is<Pointer>()) {
             current_expr = Dereference(current_expr);
         }
@@ -4533,6 +4530,7 @@
         GetDefInfo(inst.result_id())->sink_pointer_source_expr = current_expr;
     }
 
+    current_expr.type = RemapPointerProperties(current_expr.type, inst.result_id());
     return current_expr;
 }
 
@@ -4799,32 +4797,27 @@
             ++index;
             auto& info = def_info_[result_id];
 
-            // Determine address space for pointer values. Do this in order because
-            // we might rely on the address space for a previously-visited definition.
-            // Logical pointers can't be transmitted through OpPhi, so remaining
-            // pointer definitions are SSA values, and their definitions must be
-            // visited before their uses.
             const auto* type = type_mgr_->GetType(inst.type_id());
             if (type) {
+                // Determine address space and access mode for pointer values. Do this in
+                // order because we might rely on the storage class for a previously-visited
+                // definition.
+                // Logical pointers can't be transmitted through OpPhi, so remaining
+                // pointer definitions are SSA values, and their definitions must be
+                // visited before their uses.
                 if (type->AsPointer()) {
-                    if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
-                        if (auto* ptr = ast_type->As<Pointer>()) {
-                            info->pointer.address_space = ptr->address_space;
-                        }
-                    }
                     switch (inst.opcode()) {
                         case SpvOpUndef:
                             return Fail() << "undef pointer is not valid: " << inst.PrettyPrint();
                         case SpvOpVariable:
-                            // Keep the default decision based on the result type.
+                            info->pointer = GetPointerInfo(result_id);
                             break;
                         case SpvOpAccessChain:
                         case SpvOpInBoundsAccessChain:
                         case SpvOpCopyObject:
                             // Inherit from the first operand. We need this so we can pick up
                             // a remapped storage buffer.
-                            info->pointer.address_space =
-                                GetAddressSpaceForPointerValue(inst.GetSingleWordInOperand(0));
+                            info->pointer = GetPointerInfo(inst.GetSingleWordInOperand(0));
                             break;
                         default:
                             return Fail() << "pointer defined in function from unknown opcode: "
@@ -4846,32 +4839,71 @@
     return true;
 }
 
-ast::AddressSpace FunctionEmitter::GetAddressSpaceForPointerValue(uint32_t id) {
+DefInfo::Pointer FunctionEmitter::GetPointerInfo(uint32_t id) {
+    // Compute the result from first principles, for a variable or function parameter.
+    auto get_from_root_identifier =
+        [&](const spvtools::opt::Instruction& inst) -> DefInfo::Pointer {
+        // WGSL root identifiers (or SPIR-V "memory object declarations") are
+        // either variables or function parameters.
+        switch (inst.opcode()) {
+            case SpvOpVariable: {
+                if (const auto* module_var = parser_impl_.GetModuleVariable(id)) {
+                    return DefInfo::Pointer{module_var->declared_address_space,
+                                            module_var->declared_access};
+                }
+                // Local variables are always Function storage class, with default
+                // access mode.
+                return DefInfo::Pointer{ast::AddressSpace::kFunction, ast::Access::kUndefined};
+            }
+            case SpvOpFunctionParameter: {
+                const auto* type = As<Pointer>(parser_impl_.ConvertType(inst.type_id()));
+                // TODO(crbug.com/tint/1041): Add access mode.
+                // Using kUndefined is ok for now, since the only non-default access mode
+                // on a pointer would be for a storage buffer, and baseline SPIR-V doesn't
+                // allow passing pointers to buffers as function parameters.
+                return DefInfo::Pointer{type->address_space, ast::Access::kUndefined};
+            }
+            default:
+                break;
+        }
+        TINT_ASSERT(Reader, false && "expected a memory object declaration");
+        return {};
+    };
+
     auto where = def_info_.find(id);
     if (where != def_info_.end()) {
-        auto candidate = where->second.get()->pointer.address_space;
-        if (candidate != ast::AddressSpace::kInvalid) {
-            return candidate;
+        const auto& info = where->second;
+        if (info->inst.opcode() == SpvOpVariable) {
+            // Ignore the cache in this case and compute it from scratch.
+            // That's because a function-scope OpVariable is a
+            // locally-defined value.  So its cache entry has been created
+            // with a default PointerInfo object, which has invalid data.
+            //
+            // You might think that we could forget this weirdness and
+            // have more standard cache-like behaviour instead. But then
+            // for non-function-scope variables we would look up information
+            // from a saved ast::Var, and some builtins don't correspond
+            // to a declared ast::Var. This is simpler and more reliable.
+            return get_from_root_identifier(info->inst);
         }
+        // Use the cached value.
+        return info->pointer;
     }
-    const auto type_id = def_use_mgr_->GetDef(id)->type_id();
-    if (type_id) {
-        auto* ast_type = parser_impl_.ConvertType(type_id);
-        if (auto* ptr = As<Pointer>(ast_type)) {
-            return ptr->address_space;
-        }
-    }
-    return ast::AddressSpace::kInvalid;
+    const auto* inst = def_use_mgr_->GetDef(id);
+    TINT_ASSERT(Reader, inst);
+    return get_from_root_identifier(*inst);
 }
 
-const Type* FunctionEmitter::RemapAddressSpace(const Type* type, uint32_t result_id) {
+const Type* FunctionEmitter::RemapPointerProperties(const Type* type, uint32_t result_id) {
     if (auto* ast_ptr_type = As<Pointer>(type)) {
-        // Remap an old-style storage buffer pointer to a new-style storage
-        // buffer pointer.
-        const auto addr_space = GetAddressSpaceForPointerValue(result_id);
-        if (ast_ptr_type->address_space != addr_space) {
-            return ty_.Pointer(ast_ptr_type->type, addr_space);
-        }
+        const auto pi = GetPointerInfo(result_id);
+        // TODO(crbug.com/tint/1041): also do access mode
+        return ty_.Pointer(ast_ptr_type->type, pi.address_space);
+    }
+    if (auto* ast_ptr_type = As<Reference>(type)) {
+        const auto pi = GetPointerInfo(result_id);
+        // TODO(crbug.com/tint/1041): also do access mode
+        return ty_.Reference(ast_ptr_type->type, pi.address_space);
     }
     return type;
 }
diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h
index dc6002e..528799a 100644
--- a/src/tint/reader/spirv/function.h
+++ b/src/tint/reader/spirv/function.h
@@ -334,7 +334,8 @@
         /// This is kInvalid for non-pointers.
         ast::AddressSpace address_space = ast::AddressSpace::kInvalid;
 
-        // TODO(crbug.com/tint/1041) track access mode.
+        /// The declared access mode.
+        ast::Access access = ast::kUndefined;
     };
 
     /// The expression to use when sinking pointers into their use.
@@ -619,19 +620,23 @@
     /// @returns false on failure
     bool RegisterLocallyDefinedValues();
 
-    /// Returns the Tint address space for the given SPIR-V ID that is a
-    /// pointer value.
+    /// Returns the pointer information needed for the given SPIR-V ID.
+    /// Assumes the given ID yields a value of pointer type.  For IDs
+    /// corresponding to WGSL root identifiers (i.e. OpVariable or
+    /// OpFunctionParameter), the info is computed from scratch.
+    /// Otherwise, this looks up pointer info from a base pointer whose
+    /// data is cached in def_info_.
     /// @param id a SPIR-V ID for a pointer value
-    /// @returns the address space
-    ast::AddressSpace GetAddressSpaceForPointerValue(uint32_t id);
+    /// @returns the associated Pointer info
+    DefInfo::Pointer GetPointerInfo(uint32_t id);
 
-    /// Remaps the address space for the type of a locally-defined value,
-    /// if necessary. If it's not a pointer type, or if its address space
-    /// already matches, then the result is a copy of the `type` argument.
+    /// Remaps the address space and access mode for the type of a
+    /// locally-defined value, if necessary. If it's not a pointer or reference
+    /// type, then the result is a copy of the `type` argument.
     /// @param type the AST type
     /// @param result_id the SPIR-V ID for the locally defined value
     /// @returns an possibly updated type
-    const Type* RemapAddressSpace(const Type* type, uint32_t result_id);
+    const Type* RemapPointerProperties(const Type* type, uint32_t result_id);
 
     /// Marks locally defined values when they should get a 'let'
     /// definition in WGSL, or a 'var' definition at an outer scope.
@@ -1011,13 +1016,6 @@
     /// @returns a new expression node
     TypedExpression MakeOperand(const spvtools::opt::Instruction& inst, uint32_t operand_index);
 
-    /// Copies a typed expression to the result, but when the type is a pointer
-    /// or reference type, ensures the address space is not defaulted.  That is,
-    /// it changes a address space of "none" to "function".
-    /// @param expr a typed expression
-    /// @results a copy of the expression, with possibly updated type
-    TypedExpression InferFunctionAddressSpace(TypedExpression expr);
-
     /// Returns an expression for a SPIR-V OpFMod instruction.
     /// @param inst the SPIR-V instruction
     /// @returns an expression
diff --git a/src/tint/reader/spirv/function_memory_test.cc b/src/tint/reader/spirv/function_memory_test.cc
index dbf9219..a64aac5 100644
--- a/src/tint/reader/spirv/function_memory_test.cc
+++ b/src/tint/reader/spirv/function_memory_test.cc
@@ -938,6 +938,43 @@
 )"));
 }
 
+TEST_F(SpvParserMemoryTest, RemapStorageBuffer_ThroughAccessChain_NonCascaded_UsedTwice) {
+    // Use the pointer value twice, which provokes the spirv-reader
+    // to make a let declaration for the pointer.  The storage class
+    // must be 'storage', not 'uniform'.
+    const auto assembly = OldStorageBufferPreamble() + R"(
+  %100 = OpFunction %void None %voidfn
+  %entry = OpLabel
+
+  ; the scalar element
+  %1 = OpAccessChain %ptr_uint %myvar %uint_0
+  OpStore %1 %uint_0
+  OpStore %1 %uint_0
+
+  ; element in the runtime array
+  %2 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1
+  ; Use the pointer twice
+  %3 = OpLoad %uint %2
+  OpStore %2 %uint_0
+
+  OpReturn
+  OpFunctionEnd
+)";
+    auto p = parser(test::Assemble(assembly));
+    ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
+    auto fe = p->function_emitter(100);
+    EXPECT_TRUE(fe.EmitBody()) << p->error();
+    auto ast_body = fe.ast_body();
+    const auto got = test::ToString(p->program(), ast_body);
+    EXPECT_THAT(got, HasSubstr(R"(let x_1 : ptr<storage, u32> = &(myvar.field0);
+*(x_1) = 0u;
+*(x_1) = 0u;
+let x_2 : ptr<storage, u32> = &(myvar.field1[1u]);
+let x_3 : u32 = *(x_2);
+*(x_2) = 0u;
+)"));
+}
+
 TEST_F(SpvParserMemoryTest, RemapStorageBuffer_ThroughAccessChain_NonCascaded_InBoundsAccessChain) {
     // Like the previous test, but using OpInBoundsAccessChain.
     const auto assembly = OldStorageBufferPreamble() + R"(
@@ -1020,56 +1057,6 @@
     p->SkipDumpingPending("crbug.com/tint/1041 track access mode in spirv-reader parser type");
 }
 
-TEST_F(SpvParserMemoryTest, RemapStorageBuffer_ThroughCopyObject_WithHoisting) {
-    // TODO(dneto): Hoisting non-storable values (pointers) is not yet supported.
-    // It's debatable whether this test should run at all.
-    // crbug.com/tint/98
-
-    // Like the previous test, but the declaration for the copy-object
-    // has its declaration hoisted.
-    const auto assembly = OldStorageBufferPreamble() + R"(
-  %bool = OpTypeBool
-  %cond = OpConstantTrue %bool
-
-  %100 = OpFunction %void None %voidfn
-
-  %entry = OpLabel
-  OpSelectionMerge %99 None
-  OpBranchConditional %cond %20 %30
-
-  %20 = OpLabel
-  %1 = OpAccessChain %ptr_uint %myvar %uint_1 %uint_1
-  ; this definintion dominates the use in %99
-  %2 = OpCopyObject %ptr_uint %1
-  OpBranch %99
-
-  %30 = OpLabel
-  OpReturn
-
-  %99 = OpLabel
-  OpStore %2 %uint_0
-  OpReturn
-
-  OpFunctionEnd
-)";
-    auto p = parser(test::Assemble(assembly));
-    ASSERT_TRUE(p->BuildAndParseInternalModule()) << assembly << p->error();
-    auto fe = p->function_emitter(100);
-    EXPECT_TRUE(fe.EmitBody()) << p->error();
-    auto ast_body = fe.ast_body();
-    EXPECT_EQ(test::ToString(p->program(), ast_body),
-              R"(var x_2 : ptr<storage, u32>;
-if (true) {
-  x_2 = &(myvar.field1[1u]);
-} else {
-  return;
-}
-x_2 = 0u;
-return;
-)") << p->error();
-    p->SkipDumpingPending("crbug.com/tint/98");
-}
-
 std::string RuntimeArrayPreamble() {
     return R"(
      OpCapability Shader