Add Transform::CreateASTTypeFor()

Reconstructs the AST nodes needed to build the given semantic type.

Bug: tint:724
Change-Id: Iadf97a47b68088a6a1eb1e6871fb3a7248676417
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49745
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index aff37fc..064c2f9 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -564,6 +564,7 @@
     sem/intrinsic_test.cc
     symbol_table_test.cc
     symbol_test.cc
+    transform/transform_test.cc
     test_main.cc
     sem/access_control_type_test.cc
     sem/alias_type_test.cc
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index a06d75e..d971ab9 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -71,5 +71,49 @@
   return new_decorations;
 }
 
+ast::Type* Transform::CreateASTTypeFor(CloneContext* ctx, const sem::Type* ty) {
+  if (ty->Is<sem::Void>()) {
+    return ctx->dst->create<ast::Void>();
+  }
+  if (ty->Is<sem::I32>()) {
+    return ctx->dst->create<ast::I32>();
+  }
+  if (ty->Is<sem::U32>()) {
+    return ctx->dst->create<ast::U32>();
+  }
+  if (ty->Is<sem::F32>()) {
+    return ctx->dst->create<ast::F32>();
+  }
+  if (ty->Is<sem::Bool>()) {
+    return ctx->dst->create<ast::Bool>();
+  }
+  if (auto* m = ty->As<sem::Matrix>()) {
+    auto* el = CreateASTTypeFor(ctx, m->type());
+    return ctx->dst->create<ast::Matrix>(el, m->rows(), m->columns());
+  }
+  if (auto* v = ty->As<sem::Vector>()) {
+    auto* el = CreateASTTypeFor(ctx, v->type());
+    return ctx->dst->create<ast::Vector>(el, v->size());
+  }
+  if (auto* a = ty->As<sem::ArrayType>()) {
+    auto* el = CreateASTTypeFor(ctx, a->type());
+    auto decos = ctx->Clone(a->decorations());
+    return ctx->dst->create<ast::Array>(el, a->size(), std::move(decos));
+  }
+  if (auto* ac = ty->As<sem::AccessControl>()) {
+    auto* el = CreateASTTypeFor(ctx, ac->type());
+    return ctx->dst->create<ast::AccessControl>(ac->access_control(), el);
+  }
+  if (auto* a = ty->As<sem::Alias>()) {
+    return ctx->dst->create<ast::TypeName>(ctx->Clone(a->symbol()));
+  }
+  if (auto* s = ty->As<sem::StructType>()) {
+    return ctx->dst->create<ast::TypeName>(ctx->Clone(s->impl()->name()));
+  }
+  TINT_UNREACHABLE(ctx->dst->Diagnostics())
+      << "Unhandled type: " << ty->TypeInfo().name;
+  return nullptr;
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 95ba235..3506af1 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -179,6 +179,14 @@
       CloneContext* ctx,
       const ast::DecorationList& in,
       std::function<bool(const ast::Decoration*)> should_remove);
+
+  /// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
+  /// semantic type `ty`.
+  /// @param ctx the clone context
+  /// @param ty the semantic type to reconstruct
+  /// @returns a ast::Type that when resolved, will produce the semantic type
+  /// `ty`.
+  static ast::Type* CreateASTTypeFor(CloneContext* ctx, const sem::Type* ty);
 };
 
 }  // namespace transform
diff --git a/src/transform/transform_test.cc b/src/transform/transform_test.cc
new file mode 100644
index 0000000..f287f57
--- /dev/null
+++ b/src/transform/transform_test.cc
@@ -0,0 +1,122 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/transform.h"
+#include "src/clone_context.h"
+#include "src/program_builder.h"
+
+#include "gtest/gtest.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+// Inherit from Transform so we have access to protected methods
+struct CreateASTTypeForTest : public testing::Test, public Transform {
+  Output Run(const Program*, const DataMap&) override { return {}; }
+
+  ast::Type* create(
+      std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
+    ProgramBuilder sem_type_builder;
+    auto* sem_type = create_sem_type(sem_type_builder);
+    Program program(std::move(sem_type_builder));
+    CloneContext ctx(&ast_type_builder, &program, false);
+    return CreateASTTypeFor(&ctx, sem_type);
+  }
+
+  ProgramBuilder ast_type_builder;
+};
+
+TEST_F(CreateASTTypeForTest, Basic) {
+  EXPECT_TRUE(create([](ProgramBuilder& b) {
+                return b.create<sem::I32>();
+              })->Is<ast::I32>());
+  EXPECT_TRUE(create([](ProgramBuilder& b) {
+                return b.create<sem::U32>();
+              })->Is<ast::U32>());
+  EXPECT_TRUE(create([](ProgramBuilder& b) {
+                return b.create<sem::F32>();
+              })->Is<ast::F32>());
+  EXPECT_TRUE(create([](ProgramBuilder& b) {
+                return b.create<sem::Bool>();
+              })->Is<ast::Bool>());
+  EXPECT_TRUE(create([](ProgramBuilder& b) {
+                return b.create<sem::Void>();
+              })->Is<ast::Void>());
+}
+
+TEST_F(CreateASTTypeForTest, Matrix) {
+  auto* mat = create([](ProgramBuilder& b) {
+    return b.create<sem::Matrix>(b.create<sem::F32>(), 2, 3);
+  });
+  ASSERT_TRUE(mat->Is<ast::Matrix>());
+  ASSERT_TRUE(mat->As<ast::Matrix>()->type()->Is<ast::F32>());
+  ASSERT_EQ(mat->As<ast::Matrix>()->columns(), 3u);
+  ASSERT_EQ(mat->As<ast::Matrix>()->rows(), 2u);
+}
+
+TEST_F(CreateASTTypeForTest, Vector) {
+  auto* vec = create([](ProgramBuilder& b) {
+    return b.create<sem::Vector>(b.create<sem::F32>(), 2);
+  });
+  ASSERT_TRUE(vec->Is<ast::Vector>());
+  ASSERT_TRUE(vec->As<ast::Vector>()->type()->Is<ast::F32>());
+  ASSERT_EQ(vec->As<ast::Vector>()->size(), 2u);
+}
+
+TEST_F(CreateASTTypeForTest, Array) {
+  auto* arr = create([](ProgramBuilder& b) {
+    return b.create<sem::ArrayType>(b.create<sem::F32>(), 4,
+                                    ast::DecorationList{
+                                        b.create<ast::StrideDecoration>(32u),
+                                    });
+  });
+  ASSERT_TRUE(arr->Is<ast::Array>());
+  ASSERT_TRUE(arr->As<ast::Array>()->type()->Is<ast::F32>());
+  ASSERT_EQ(arr->As<ast::Array>()->size(), 4u);
+  ASSERT_EQ(arr->As<ast::Array>()->decorations().size(), 1u);
+  ASSERT_TRUE(
+      arr->As<ast::Array>()->decorations()[0]->Is<ast::StrideDecoration>());
+  ASSERT_EQ(arr->As<ast::Array>()
+                ->decorations()[0]
+                ->As<ast::StrideDecoration>()
+                ->stride(),
+            32u);
+}
+
+TEST_F(CreateASTTypeForTest, AccessControl) {
+  auto* ac = create([](ProgramBuilder& b) {
+    auto str = b.Structure("S", {}, {});
+    return b.create<sem::AccessControl>(ast::AccessControl::kReadOnly, str);
+  });
+  ASSERT_TRUE(ac->Is<ast::AccessControl>());
+  EXPECT_EQ(ac->As<ast::AccessControl>()->access_control(),
+            ast::AccessControl::kReadOnly);
+  EXPECT_TRUE(ac->As<ast::AccessControl>()->type()->Is<ast::TypeName>());
+}
+
+TEST_F(CreateASTTypeForTest, Struct) {
+  auto* str = create([](ProgramBuilder& b) {
+    auto* impl = b.Structure("S", {}, {}).ast;
+    return b.create<sem::StructType>(const_cast<ast::Struct*>(impl));
+  });
+  ASSERT_TRUE(str->Is<ast::TypeName>());
+  EXPECT_EQ(
+      ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name()),
+      "S");
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index e999f6c54..c7bed03 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -297,6 +297,7 @@
     "../src/transform/first_index_offset_test.cc",
     "../src/transform/renamer_test.cc",
     "../src/transform/single_entry_point_test.cc",
+    "../src/transform/transform_test.cc",
     "../src/transform/vertex_pulling_test.cc",
     "../src/utils/command_test.cc",
     "../src/utils/get_or_create_test.cc",