Resolver: Track storage class usages of structures

This will be used to validate layout rules, and to prevent illegal
types from being used in uniform / storage buffers.

Also: Clean up logic around VariableDeclStatement.
This was unnecessarily spread across 3 places.

Bug: tint:643
Change-Id: I9d309c3a5dfb5676984f49ce51763a97bcac93bb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45125
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 53a69f9..6e16b2e 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -151,6 +151,11 @@
         return false;
       }
     }
+
+    if (!ApplyStorageClassUsageToType(var->declared_storage_class(),
+                                      var->type())) {
+      return false;
+    }
   }
 
   if (!Functions(builder_->AST().Functions())) {
@@ -200,16 +205,6 @@
 
 bool Resolver::Statements(const ast::StatementList& stmts) {
   for (auto* stmt : stmts) {
-    if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
-      if (!VariableDeclStatement(decl)) {
-        return false;
-      }
-    }
-
-    if (!VariableStorageClass(stmt)) {
-      return false;
-    }
-
     if (!Statement(stmt)) {
       return false;
     }
@@ -217,36 +212,6 @@
   return true;
 }
 
-bool Resolver::VariableStorageClass(ast::Statement* stmt) {
-  auto* var_decl = stmt->As<ast::VariableDeclStatement>();
-  if (var_decl == nullptr) {
-    return true;
-  }
-
-  auto* var = var_decl->variable();
-
-  auto* info = CreateVariableInfo(var);
-  variable_to_info_.emplace(var, info);
-
-  // Nothing to do for const
-  if (var->is_const()) {
-    return true;
-  }
-
-  if (info->storage_class == ast::StorageClass::kFunction) {
-    return true;
-  }
-
-  if (info->storage_class != ast::StorageClass::kNone) {
-    diagnostics_.add_error("function variable has a non-function storage class",
-                           stmt->source());
-    return false;
-  }
-
-  info->storage_class = ast::StorageClass::kFunction;
-  return true;
-}
-
 bool Resolver::Statement(ast::Statement* stmt) {
   auto* sem_statement = builder_->create<semantic::Statement>(stmt);
 
@@ -336,10 +301,7 @@
     return true;
   }
   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    variable_stack_.set(v->variable()->symbol(),
-                        variable_to_info_.at(v->variable()));
-    current_block_->decls.push_back(v->variable());
-    return Expression(v->variable()->constructor());
+    return VariableDeclStatement(v);
   }
 
   diagnostics_.add_error(
@@ -1118,21 +1080,44 @@
 }
 
 bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
-  auto* ctor = stmt->variable()->constructor();
-  if (!ctor) {
-    return true;
-  }
-
-  if (auto* sce = ctor->As<ast::ScalarConstructorExpression>()) {
-    auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded();
-    auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded();
-
-    if (lhs_type != rhs_type) {
-      diagnostics_.add_error(
-          "constructor expression type does not match variable type",
-          stmt->source());
+  if (auto* ctor = stmt->variable()->constructor()) {
+    if (!Expression(ctor)) {
       return false;
     }
+    if (auto* sce = ctor->As<ast::ScalarConstructorExpression>()) {
+      auto* lhs_type = stmt->variable()->type()->UnwrapAliasIfNeeded();
+      auto* rhs_type = sce->literal()->type()->UnwrapAliasIfNeeded();
+
+      if (lhs_type != rhs_type) {
+        diagnostics_.add_error(
+            "constructor expression type does not match variable type",
+            stmt->source());
+        return false;
+      }
+    }
+  }
+
+  auto* var = stmt->variable();
+
+  auto* info = CreateVariableInfo(var);
+  variable_to_info_.emplace(var, info);
+  variable_stack_.set(var->symbol(), info);
+  current_block_->decls.push_back(var);
+
+  if (!var->is_const()) {
+    if (info->storage_class != ast::StorageClass::kFunction) {
+      if (info->storage_class != ast::StorageClass::kNone) {
+        diagnostics_.add_error(
+            "function variable has a non-function storage class",
+            stmt->source());
+        return false;
+      }
+      info->storage_class = ast::StorageClass::kFunction;
+    }
+  }
+
+  if (!ApplyStorageClassUsageToType(info->storage_class, var->type())) {
+    return false;
   }
 
   return true;
@@ -1247,9 +1232,10 @@
   for (auto it : struct_info_) {
     auto* str = it.first;
     auto* info = it.second;
-    builder_->Sem().Add(str, builder_->create<semantic::Struct>(
-                                 str, std::move(info->members), info->align,
-                                 info->size, info->size_no_padding));
+    builder_->Sem().Add(
+        str, builder_->create<semantic::Struct>(
+                 str, std::move(info->members), info->align, info->size,
+                 info->size_no_padding, info->storage_class_usage));
   }
 }
 
@@ -1470,6 +1456,44 @@
   return info;
 }
 
+bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
+                                            type::Type* ty) {
+  // Strip aliases *and* access-control wrappers; stripping only aliases means
+  // access-qualified structures (as used by storage buffers) are never seen.
+  ty = ty->UnwrapIfNeeded();
+  if (auto* str = ty->As<type::Struct>()) {
+    auto* info = Structure(str);
+    if (!info) {
+      return false;
+    }
+    if (!info->storage_class_usage.emplace(sc).second) {
+      return true;  // Already applied
+    }
+    for (auto* member : str->impl()->members()) {
+      // TODO(amaiorano): Determine the host-sharable types
+      bool can_be_host_sharable = true;
+      if (ast::IsHostSharable(sc) && !can_be_host_sharable) {
+        std::stringstream err;
+        err << "Structure '" << str->FriendlyName(builder_->Symbols())
+            << "' is used by storage class " << sc
+            << " which contains a member of non-host-sharable type "
+            << member->type()->FriendlyName(builder_->Symbols());
+        diagnostics_.add_error(err.str(), member->source());
+        return false;
+      }
+      // Usage propagates transitively into member (and array element) types.
+      if (!ApplyStorageClassUsageToType(sc, member->type())) {
+        return false;
+      }
+    }
+    return true;
+  }
+  if (auto* arr = ty->As<type::Array>()) {
+    return ApplyStorageClassUsageToType(sc, arr->type());
+  }
+  return true;
+}
+
 template <typename F>
 bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
   BlockInfo block_info(type, current_block_);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 4bbe504..7bb56b6 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "src/intrinsic_table.h"
@@ -124,6 +125,7 @@
     uint32_t align = 0;
     uint32_t size = 0;
     uint32_t size_no_padding = 0;
+    std::unordered_set<ast::StorageClass> storage_class_usage;
   };
 
   /// Structure holding semantic information about a block (i.e. scope), such as
@@ -206,7 +208,6 @@
   bool Statements(const ast::StatementList&);
   bool UnaryOp(ast::UnaryOpExpression*);
   bool VariableDeclStatement(const ast::VariableDeclStatement*);
-  bool VariableStorageClass(ast::Statement*);
 
   /// @returns the semantic information for the array `arr`, building it if it
   /// hasn't been constructed already. If an error is raised, nullptr is
@@ -217,6 +218,12 @@
   /// been constructed already. If an error is raised, nullptr is returned.
   StructInfo* Structure(type::Struct* str);
 
+  /// Records the storage class usage for the given type, and any transitive
+  /// dependencies of the type (structure members and array element types).
+  /// Validates the type can be used for the storage class, erroring if not.
+  /// @returns true on success, false on error
+  bool ApplyStorageClassUsageToType(ast::StorageClass, type::Type*);
+
   /// @param align the output default alignment in bytes for the type `ty`
   /// @param size the output default size in bytes for the type `ty`
   /// @returns true on success, false on error
diff --git a/src/resolver/struct_storage_class_use_test.cc b/src/resolver/struct_storage_class_use_test.cc
new file mode 100644
index 0000000..34fa7e2
--- /dev/null
+++ b/src/resolver/struct_storage_class_use_test.cc
@@ -0,0 +1,161 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/resolver/resolver.h"
+
+#include "gmock/gmock.h"
+#include "src/resolver/resolver_test_helper.h"
+#include "src/semantic/struct.h"
+
+using ::testing::UnorderedElementsAre;
+
+namespace tint {
+namespace resolver {
+namespace {
+
+using ResolverStorageClassUseTest = ResolverTest;
+
+TEST_F(ResolverStorageClassUseTest, UnreachableStruct) {
+  // No variable references S, so it must have no storage class usages.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  EXPECT_TRUE(str_sem->StorageClassUsage().empty());
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) {
+  // S is used directly as the type of a `storage` global.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  Global("g", str, ast::StorageClass::kStorage);
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+
+  const auto expected = ast::StorageClass::kStorage;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) {
+  // The global is declared via alias A, which must be seen through to S.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* alias = ty.alias("A", str);
+  Global("g", alias, ast::StorageClass::kStorage);
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kStorage;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
+  // S is only reachable as a member of the outer structure O.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* outer = Structure("O", {Member("a", str)});
+  Global("g", outer, ast::StorageClass::kStorage);
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kStorage;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
+  // S is only reachable as the element type of a 3-element array.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* arr = ty.array(str, 3);
+  Global("g", arr, ast::StorageClass::kStorage);
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kStorage;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) {
+  // S is used directly as the type of a function-scope variable.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  WrapInFunction(Var("g", str, ast::StorageClass::kFunction));
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+
+  const auto expected = ast::StorageClass::kFunction;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalAlias) {
+  // The local is declared via alias A, which must be seen through to S.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* alias = ty.alias("A", str);
+  WrapInFunction(Var("g", alias, ast::StorageClass::kFunction));
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kFunction;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalStruct) {
+  // S is only reachable as a member of the outer structure O.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* outer = Structure("O", {Member("a", str)});
+  WrapInFunction(Var("g", outer, ast::StorageClass::kFunction));
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kFunction;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalArray) {
+  // S is only reachable as the element type of a 3-element array.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  auto* arr = ty.array(str, 3);
+  WrapInFunction(Var("g", arr, ast::StorageClass::kFunction));
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  const auto expected = ast::StorageClass::kFunction;
+  EXPECT_THAT(str_sem->StorageClassUsage(), UnorderedElementsAre(expected));
+}
+
+TEST_F(ResolverStorageClassUseTest, StructMultipleStorageClassUses) {
+  // Every distinct storage class that S is used with must be recorded.
+  auto* str = Structure("S", {Member("a", ty.f32())});
+  Global("x", str, ast::StorageClass::kStorage);
+  Global("y", str, ast::StorageClass::kUniform);
+  WrapInFunction(Var("g", str, ast::StorageClass::kFunction));
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* str_sem = Sem().Get(str);
+  ASSERT_NE(str_sem, nullptr);
+  EXPECT_THAT(str_sem->StorageClassUsage(),
+              UnorderedElementsAre(ast::StorageClass::kFunction,
+                                   ast::StorageClass::kStorage,
+                                   ast::StorageClass::kUniform));
+}
+
+}  // namespace
+}  // namespace resolver
+}  // namespace tint