[ast] Determining subtype of StorageTextureType

It's determined at the beginning of TypeDeterminer::Determine().

Bug: tint:141
Change-Id: I761199db0c9813dbd42c6cb4ecb3532d1a39f49f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27460
Commit-Queue: Tomek Ponitka <tommek@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/type/storage_texture_type.cc b/src/ast/type/storage_texture_type.cc
index a61fc01..f3f3b9f 100644
--- a/src/ast/type/storage_texture_type.cc
+++ b/src/ast/type/storage_texture_type.cc
@@ -164,6 +164,14 @@
   assert(IsValidStorageDimension(dim));
 }
 
+void StorageTextureType::set_type(Type* const type) {
+  type_ = type;
+}
+
+Type* StorageTextureType::type() const {
+  return type_;
+}
+
 StorageTextureType::StorageTextureType(StorageTextureType&&) = default;
 
 StorageTextureType::~StorageTextureType() = default;
diff --git a/src/ast/type/storage_texture_type.h b/src/ast/type/storage_texture_type.h
index 2c27ffd..ae6c4fd 100644
--- a/src/ast/type/storage_texture_type.h
+++ b/src/ast/type/storage_texture_type.h
@@ -74,9 +74,10 @@
   /// @param dim the dimensionality of the texture
   /// @param access the access type for the texture
   /// @param format the image format of the texture
-  explicit StorageTextureType(TextureDimension dim,
-                              StorageAccess access,
-                              ImageFormat format);
+  StorageTextureType(TextureDimension dim,
+                     StorageAccess access,
+                     ImageFormat format);
+
   /// Move constructor
   StorageTextureType(StorageTextureType&&);
   ~StorageTextureType() override;
@@ -84,8 +85,11 @@
   /// @returns true if the type is a storage texture type
   bool IsStorage() const override;
 
-  /// @returns the subtype of the sampled texture
-  Type* type() const { return type_; }
+  /// @param type the subtype of the storage texture
+  void set_type(Type* const type);
+
+  /// @returns the subtype of the storage texture set with set_type
+  Type* type() const;
 
   /// @returns the storage access
   StorageAccess access() const { return storage_access_; }
diff --git a/src/ast/type/storage_texture_type_test.cc b/src/ast/type/storage_texture_type_test.cc
index 6c834a0..bc0f503 100644
--- a/src/ast/type/storage_texture_type_test.cc
+++ b/src/ast/type/storage_texture_type_test.cc
@@ -14,6 +14,9 @@
 
 #include "src/ast/type/storage_texture_type.h"
 
+#include "src/ast/identifier_expression.h"
+#include "src/type_determiner.h"
+
 #include "gtest/gtest.h"
 
 namespace tint {
@@ -72,6 +75,48 @@
   EXPECT_EQ(s.type_name(), "__storage_texture_read_2d_array_rgba32float");
 }
 
+TEST_F(StorageTextureTypeTest, F32Type) {
+  Context ctx;
+  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+      TextureDimension::k2dArray, StorageAccess::kRead,
+      ImageFormat::kRgba32Float));
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+  ASSERT_TRUE(s->IsTexture());
+  ASSERT_TRUE(s->AsTexture()->IsStorage());
+  EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsF32());
+}
+
+TEST_F(StorageTextureTypeTest, U32Type) {
+  Context ctx;
+  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+      TextureDimension::k2dArray, StorageAccess::kRead,
+      ImageFormat::kRgba8Unorm));
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+  ASSERT_TRUE(s->IsTexture());
+  ASSERT_TRUE(s->AsTexture()->IsStorage());
+  EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsU32());
+}
+
+TEST_F(StorageTextureTypeTest, I32Type) {
+  Context ctx;
+  ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>(
+      TextureDimension::k2dArray, StorageAccess::kRead,
+      ImageFormat::kRgba32Sint));
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+  ASSERT_TRUE(s->IsTexture());
+  ASSERT_TRUE(s->AsTexture()->IsStorage());
+  EXPECT_TRUE(s->AsTexture()->AsStorage()->type()->IsI32());
+}
+
 }  // namespace
 }  // namespace type
 }  // namespace ast
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index a559d8b..9f2e90d 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -41,9 +41,11 @@
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/bool_type.h"
 #include "src/ast/type/f32_type.h"
+#include "src/ast/type/i32_type.h"
 #include "src/ast/type/matrix_type.h"
 #include "src/ast/type/pointer_type.h"
 #include "src/ast/type/struct_type.h"
+#include "src/ast/type/u32_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/unary_op_expression.h"
@@ -177,6 +179,18 @@
 }
 
 bool TypeDeterminer::Determine() {
+  for (auto& iter : ctx_.type_mgr().types()) {
+    auto& type = iter.second;
+    if (!type->IsTexture() || !type->AsTexture()->IsStorage()) {
+      continue;
+    }
+    if (!DetermineStorageTextureSubtype(type->AsTexture()->AsStorage())) {
+      set_error(Source{}, "unable to determine storage texture subtype for: " +
+                              type->type_name());
+      return false;
+    }
+  }
+
   for (const auto& var : mod_->global_variables()) {
     variable_stack_.set_global(var->name(), var.get());
 
@@ -824,6 +838,67 @@
   return true;
 }
 
+bool TypeDeterminer::DetermineStorageTextureSubtype(
+    ast::type::StorageTextureType* tex) {
+  if (tex->type() != nullptr) {
+    return true;
+  }
+
+  switch (tex->image_format()) {
+    case ast::type::ImageFormat::kR8Unorm:
+    case ast::type::ImageFormat::kRg8Unorm:
+    case ast::type::ImageFormat::kRgba8Unorm:
+    case ast::type::ImageFormat::kRgba8UnormSrgb:
+    case ast::type::ImageFormat::kBgra8Unorm:
+    case ast::type::ImageFormat::kBgra8UnormSrgb:
+    case ast::type::ImageFormat::kRgb10A2Unorm:
+    case ast::type::ImageFormat::kR8Uint:
+    case ast::type::ImageFormat::kR16Uint:
+    case ast::type::ImageFormat::kRg8Uint:
+    case ast::type::ImageFormat::kR32Uint:
+    case ast::type::ImageFormat::kRg16Uint:
+    case ast::type::ImageFormat::kRgba8Uint:
+    case ast::type::ImageFormat::kRg32Uint:
+    case ast::type::ImageFormat::kRgba16Uint:
+    case ast::type::ImageFormat::kRgba32Uint: {
+      tex->set_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()));
+      return true;
+    }
+
+    case ast::type::ImageFormat::kR8Snorm:
+    case ast::type::ImageFormat::kRg8Snorm:
+    case ast::type::ImageFormat::kRgba8Snorm:
+    case ast::type::ImageFormat::kR8Sint:
+    case ast::type::ImageFormat::kR16Sint:
+    case ast::type::ImageFormat::kRg8Sint:
+    case ast::type::ImageFormat::kR32Sint:
+    case ast::type::ImageFormat::kRg16Sint:
+    case ast::type::ImageFormat::kRgba8Sint:
+    case ast::type::ImageFormat::kRg32Sint:
+    case ast::type::ImageFormat::kRgba16Sint:
+    case ast::type::ImageFormat::kRgba32Sint: {
+      tex->set_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()));
+      return true;
+    }
+
+    case ast::type::ImageFormat::kR16Float:
+    case ast::type::ImageFormat::kR32Float:
+    case ast::type::ImageFormat::kRg16Float:
+    case ast::type::ImageFormat::kRg11B10Float:
+    case ast::type::ImageFormat::kRg32Float:
+    case ast::type::ImageFormat::kRgba16Float:
+    case ast::type::ImageFormat::kRgba32Float: {
+      tex->set_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+      return true;
+    }
+  }
+
+  return false;
+}
+
 ast::type::Type* TypeDeterminer::GetImportData(
     const Source& source,
     const std::string& path,
diff --git a/src/type_determiner.h b/src/type_determiner.h
index cf8f8bd..290357d 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -19,6 +19,7 @@
 #include <unordered_map>
 
 #include "src/ast/module.h"
+#include "src/ast/type/storage_texture_type.h"
 #include "src/context.h"
 #include "src/scope_stack.h"
 
@@ -118,6 +119,8 @@
   bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
   bool DetermineUnaryOp(ast::UnaryOpExpression* expr);
 
+  bool DetermineStorageTextureSubtype(ast::type::StorageTextureType* tex);
+
   Context& ctx_;
   ast::Module* mod_;
   std::string error_;
diff --git a/src/type_manager.h b/src/type_manager.h
index eb26a25..2c82d76 100644
--- a/src/type_manager.h
+++ b/src/type_manager.h
@@ -37,10 +37,10 @@
   /// @return the pointer to the registered type
   ast::type::Type* Get(std::unique_ptr<ast::type::Type> type);
 
-  /// Returns the type map, for testing purposes.
+  /// Returns the type map
   /// @returns the mapping from name string to type.
   const std::unordered_map<std::string, std::unique_ptr<ast::type::Type>>&
-  TypesForTesting() {
+  types() {
     return types_;
   }
 
diff --git a/src/type_manager_test.cc b/src/type_manager_test.cc
index 1628ffb..217259a 100644
--- a/src/type_manager_test.cc
+++ b/src/type_manager_test.cc
@@ -57,9 +57,9 @@
   auto* t = tm.Get(std::make_unique<ast::type::I32Type>());
   ASSERT_NE(t, nullptr);
 
-  EXPECT_FALSE(tm.TypesForTesting().empty());
+  EXPECT_FALSE(tm.types().empty());
   tm.Reset();
-  EXPECT_TRUE(tm.TypesForTesting().empty());
+  EXPECT_TRUE(tm.types().empty());
 
   auto* t2 = tm.Get(std::make_unique<ast::type::I32Type>());
   ASSERT_NE(t2, nullptr);