resolver: Validate uniform buffer types

Fixed: tint:210
Change-Id: I7763ca23a5dce09755a1ca71d35f24897a875ac0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48604
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/ast/module_clone_test.cc b/src/ast/module_clone_test.cc
index 6f39c42..6b61483 100644
--- a/src/ast/module_clone_test.cc
+++ b/src/ast/module_clone_test.cc
@@ -39,7 +39,7 @@
 type t0 = [[stride(16)]] array<vec4<f32>>;
 type t1 = array<vec4<f32>>;
 
-var<uniform> g0 : u32 = 20u;
+var<private> g0 : u32 = 20u;
 var<private> g1 : f32 = 123.0;
 var g2 : texture_2d<f32>;
 var g3 : [[access(read)]] texture_storage_2d<r32uint>;
@@ -47,7 +47,7 @@
 var g5 : [[access(read)]] texture_storage_2d<r32uint>;
 var g6 : [[access(write)]] texture_storage_2d<rg32float>;
 
-[[builtin(position)]] var<uniform> g7 : vec3<f32>;
+var<private> g7 : vec3<f32>;
 [[group(10), binding(20)]] var<storage> g8 : [[access(write)]] S;
 [[group(10), binding(20)]] var<storage> g9 : [[access(read)]] S;
 [[group(10), binding(20)]] var<storage> g10 : [[access(read_write)]] S;
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index a14baa1..eeb0ee7 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -1822,31 +1822,6 @@
   EXPECT_TRUE(error.find("not an entry point") != std::string::npos);
 }
 
-TEST_F(InspectorGetUniformBufferResourceBindingsTest, MissingBlockDeco) {
-  ast::DecorationList decos;
-  auto* str = create<ast::Struct>(
-      Sym("foo_type"),
-      ast::StructMemberList{Member(StructMemberName(0, ty.i32()), ty.i32())},
-      decos);
-
-  auto* foo_type = ty.struct_(str);
-  AddUniformBuffer("foo_ub", foo_type, 0, 0);
-
-  MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}});
-
-  MakeCallerBodyFunction(
-      "ep_func", {"ub_func"},
-      ast::DecorationList{
-          create<ast::StageDecoration>(ast::PipelineStage::kFragment),
-      });
-
-  Inspector& inspector = Build();
-
-  auto result = inspector.GetUniformBufferResourceBindings("ep_func");
-  ASSERT_FALSE(inspector.has_error()) << inspector.error();
-  EXPECT_EQ(0u, result.size());
-}
-
 TEST_F(InspectorGetUniformBufferResourceBindingsTest, Simple) {
   sem::StructType* foo_struct_type =
       MakeUniformBufferType("foo_type", {ty.i32()});
diff --git a/src/program_builder.h b/src/program_builder.h
index 2c61450..c006792 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -985,13 +985,42 @@
                                  decorations);
   }
 
-  /// @param args the arguments to pass to Var()
-  /// @returns a `ast::Variable` constructed by calling Var() with the arguments
-  /// of `args`, which is automatically registered as a global variable with the
-  /// ast::Module.
-  template <typename... ARGS>
-  ast::Variable* Global(ARGS&&... args) {
-    auto* var = Var(std::forward<ARGS>(args)...);
+  /// @param name the variable name
+  /// @param type the variable type
+  /// @param storage the variable storage class
+  /// @param constructor constructor expression
+  /// @param decorations variable decorations
+  /// @returns a new `ast::Variable`, which is automatically registered as a
+  /// global variable with the ast::Module.
+  template <typename NAME>
+  ast::Variable* Global(NAME&& name,
+                        sem::Type* type,
+                        ast::StorageClass storage,
+                        ast::Expression* constructor = nullptr,
+                        ast::DecorationList decorations = {}) {
+    auto* var =
+        Var(std::forward<NAME>(name), type, storage, constructor, decorations);
+    AST().AddGlobalVariable(var);
+    return var;
+  }
+
+  /// @param source the variable source
+  /// @param name the variable name
+  /// @param type the variable type
+  /// @param storage the variable storage class
+  /// @param constructor constructor expression
+  /// @param decorations variable decorations
+  /// @returns a new `ast::Variable`, which is automatically registered as a
+  /// global variable with the ast::Module.
+  template <typename NAME>
+  ast::Variable* Global(const Source& source,
+                        NAME&& name,
+                        sem::Type* type,
+                        ast::StorageClass storage,
+                        ast::Expression* constructor = nullptr,
+                        ast::DecorationList decorations = {}) {
+    auto* var = Var(source, std::forward<NAME>(name), type, storage,
+                    constructor, decorations);
     AST().AddGlobalVariable(var);
     return var;
   }
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 084fe7c..6532144 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -432,37 +432,69 @@
 }
 
 bool Resolver::ValidateGlobalVariable(const VariableInfo* info) {
-  if (info->storage_class == ast::StorageClass::kStorage) {
-    // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
-    // Variables in the storage storage class and variables with a storage
-    // texture type must have an access attribute applied to the store type.
+  switch (info->storage_class) {
+    case ast::StorageClass::kStorage: {
+      // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
+      // Variables in the storage storage class and variables with a storage
+      // texture type must have an access attribute applied to the store type.
 
-    // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
-    // A variable in the storage storage class is a storage buffer variable. Its
-    // store type must be a host-shareable structure type with block attribute,
-    // satisfying the storage class constraints.
+      // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+      // A variable in the storage storage class is a storage buffer variable.
+      // Its store type must be a host-shareable structure type with block
+      // attribute, satisfying the storage class constraints.
 
-    auto* access = info->type->As<sem::AccessControl>();
-    auto* str = access ? access->type()->As<sem::StructType>() : nullptr;
-    if (!str) {
-      diagnostics_.add_error(
-          "variables declared in the <storage> storage class must be of an "
-          "[[access]] qualified structure type",
-          info->declaration->source());
-      return false;
-    }
-
-    if (!str->IsBlockDecorated()) {
-      diagnostics_.add_error(
-          "structure used as a storage buffer must be declared with the "
-          "[[block]] decoration",
-          str->impl()->source());
-      if (info->declaration->source().range.begin.line) {
-        diagnostics_.add_note("structure used as storage buffer here",
-                              info->declaration->source());
+      auto* access = info->type->As<sem::AccessControl>();
+      auto* str = access ? access->type()->As<sem::StructType>() : nullptr;
+      if (!str) {
+        diagnostics_.add_error(
+            "variables declared in the <storage> storage class must be of an "
+            "[[access]] qualified structure type",
+            info->declaration->source());
+        return false;
       }
-      return false;
+
+      if (!str->IsBlockDecorated()) {
+        diagnostics_.add_error(
+            "structure used as a storage buffer must be declared with the "
+            "[[block]] decoration",
+            str->impl()->source());
+        if (info->declaration->source().range.begin.line) {
+          diagnostics_.add_note("structure used as storage buffer here",
+                                info->declaration->source());
+        }
+        return false;
+      }
+      break;
     }
+    case ast::StorageClass::kUniform: {
+      // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+      // A variable in the uniform storage class is a uniform buffer variable.
+      // Its store type must be a host-shareable structure type with block
+      // attribute, satisfying the storage class constraints.
+      auto* str = info->type->As<sem::StructType>();
+      if (!str) {
+        diagnostics_.add_error(
+            "variables declared in the <uniform> storage class must be of a "
+            "structure type",
+            info->declaration->source());
+        return false;
+      }
+
+      if (!str->IsBlockDecorated()) {
+        diagnostics_.add_error(
+            "structure used as a uniform buffer must be declared with the "
+            "[[block]] decoration",
+            str->impl()->source());
+        if (info->declaration->source().range.begin.line) {
+          diagnostics_.add_note("structure used as uniform buffer here",
+                                info->declaration->source());
+        }
+        return false;
+      }
+      break;
+    }
+    default:
+      break;
   }
 
   return ValidateVariable(info->declaration);
diff --git a/src/resolver/storage_class_validation_test.cc b/src/resolver/storage_class_validation_test.cc
index 931b23a..1ce3cbf 100644
--- a/src/resolver/storage_class_validation_test.cc
+++ b/src/resolver/storage_class_validation_test.cc
@@ -138,6 +138,94 @@
   ASSERT_TRUE(r()->Resolve());
 }
 
+///
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) {
+  // var<uniform> g : bool;
+  Global(Source{{56, 78}}, "g", ty.bool_(), ast::StorageClass::kUniform);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
+  // var<uniform> g : ptr<i32, input>;
+  Global(Source{{56, 78}}, "g", ty.pointer<i32>(ast::StorageClass::kInput),
+         ast::StorageClass::kUniform);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
+  // var<uniform> g : [[access(read)]] array<S, 3>;
+  auto* s = Structure("S", {Member("a", ty.f32())});
+  auto* a = ty.array(s, 3);
+  auto* ac = ty.access(ast::AccessControl::kReadOnly, a);
+  Global(Source{{56, 78}}, "g", ac, ast::StorageClass::kUniform);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) {
+  // type a = bool;
+  // var<uniform> g : [[access(read)]] a;
+  auto* a = ty.alias("a", ty.bool_());
+  Global(Source{{56, 78}}, "g", a, ast::StorageClass::kUniform);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoBlockDecoration) {
+  // struct S { x : i32 };
+  // var<uniform> g : S;
+  auto* s = Structure(Source{{12, 34}}, "S", {Member("x", ty.i32())});
+  Global(Source{{56, 78}}, "g", s, ast::StorageClass::kUniform);
+
+  ASSERT_FALSE(r()->Resolve());
+
+  EXPECT_EQ(
+      r()->error(),
+      R"(12:34 error: structure used as a uniform buffer must be declared with the [[block]] decoration
+56:78 note: structure used as uniform buffer here)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Basic) {
+  // [[block]] struct S { x : i32 };
+  // var<uniform> g :  S;
+  auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+                      {create<ast::StructBlockDecoration>()});
+  Global(Source{{56, 78}}, "g", s, ast::StorageClass::kUniform);
+
+  ASSERT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Aliases) {
+  // [[block]] struct S { x : i32 };
+  // type a1 = S;
+  // var<uniform> g : a1;
+  auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+                      {create<ast::StructBlockDecoration>()});
+  auto* a1 = ty.alias("a1", s);
+  Global(Source{{56, 78}}, "g", a1, ast::StorageClass::kUniform);
+
+  ASSERT_TRUE(r()->Resolve());
+}
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/struct_storage_class_use_test.cc b/src/resolver/struct_storage_class_use_test.cc
index 4751992..a34a835 100644
--- a/src/resolver/struct_storage_class_use_test.cc
+++ b/src/resolver/struct_storage_class_use_test.cc
@@ -66,53 +66,53 @@
 TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) {
   auto* s = Structure("S", {Member("a", ty.f32())});
 
-  Global("g", s, ast::StorageClass::kUniform);
+  Global("g", s, ast::StorageClass::kPrivate);
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
 
   auto* sem = Sem().Get(s);
   ASSERT_NE(sem, nullptr);
   EXPECT_THAT(sem->StorageClassUsage(),
-              UnorderedElementsAre(ast::StorageClass::kUniform));
+              UnorderedElementsAre(ast::StorageClass::kPrivate));
 }
 
 TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) {
   auto* s = Structure("S", {Member("a", ty.f32())});
   auto* a = ty.alias("A", s);
-  Global("g", a, ast::StorageClass::kUniform);
+  Global("g", a, ast::StorageClass::kPrivate);
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
 
   auto* sem = Sem().Get(s);
   ASSERT_NE(sem, nullptr);
   EXPECT_THAT(sem->StorageClassUsage(),
-              UnorderedElementsAre(ast::StorageClass::kUniform));
+              UnorderedElementsAre(ast::StorageClass::kPrivate));
 }
 
 TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
   auto* s = Structure("S", {Member("a", ty.f32())});
   auto* o = Structure("O", {Member("a", s)});
-  Global("g", o, ast::StorageClass::kUniform);
+  Global("g", o, ast::StorageClass::kPrivate);
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
 
   auto* sem = Sem().Get(s);
   ASSERT_NE(sem, nullptr);
   EXPECT_THAT(sem->StorageClassUsage(),
-              UnorderedElementsAre(ast::StorageClass::kUniform));
+              UnorderedElementsAre(ast::StorageClass::kPrivate));
 }
 
 TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
   auto* s = Structure("S", {Member("a", ty.f32())});
   auto* a = ty.array(s, 3);
-  Global("g", a, ast::StorageClass::kUniform);
+  Global("g", a, ast::StorageClass::kPrivate);
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
 
   auto* sem = Sem().Get(s);
   ASSERT_NE(sem, nullptr);
   EXPECT_THAT(sem->StorageClassUsage(),
-              UnorderedElementsAre(ast::StorageClass::kUniform));
+              UnorderedElementsAre(ast::StorageClass::kPrivate));
 }
 
 TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) {
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index e3f733b..0c3cd2a 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -31,10 +31,9 @@
 
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
   Func("my_func", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
-       },
-       ast::DecorationList{});
+       {
+           Return(),
+       });
 
   GeneratorImpl& gen = Build();
 
@@ -50,10 +49,9 @@
 
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
   Func("GeometryShader", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
-       },
-       ast::DecorationList{});
+       {
+           Return(),
+       });
 
   GeneratorImpl& gen = SanitizeAndBuild();
 
@@ -68,10 +66,9 @@
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
   Func("my_func", ast::VariableList{Param("a", ty.f32()), Param("b", ty.i32())},
        ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
-       },
-       ast::DecorationList{});
+       {
+           Return(),
+       });
 
   GeneratorImpl& gen = Build();
 
@@ -87,10 +84,9 @@
 
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoint_NoReturn_Void) {
-  Func("main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{/* no explicit return */},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+  Func("main", ast::VariableList{}, ty.void_(), {/* no explicit return */},
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -111,9 +107,8 @@
   //   return foo;
   // }
   auto* foo_in = Param("foo", ty.f32(), {create<ast::LocationDecoration>(0)});
-  Func("frag_main", ast::VariableList{foo_in}, ty.f32(),
-       {create<ast::ReturnStatement>(Expr("foo"))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+  Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {Return(Expr("foo"))},
+       {Stage(ast::PipelineStage::kFragment)},
        {create<ast::LocationDecoration>(1)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -146,8 +141,8 @@
       Param("coord", ty.vec4<f32>(),
             {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
   Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
-       {create<ast::ReturnStatement>(MemberAccessor("coord", "x"))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+       {Return(MemberAccessor("coord", "x"))},
+       {Stage(ast::PipelineStage::kFragment)},
        {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -201,14 +196,12 @@
 
   Func("frag_main", {Param("inputs", interface_struct)}, ty.void_(),
        {
-           WrapInStatement(
-               Const("r", ty.f32(), MemberAccessor(Expr("inputs"), "col1"))),
-           WrapInStatement(
-               Const("g", ty.f32(), MemberAccessor(Expr("inputs"), "col2"))),
-           WrapInStatement(Const("p", ty.vec4<f32>(),
-                                 MemberAccessor(Expr("inputs"), "pos"))),
+           Decl(Const("r", ty.f32(), MemberAccessor(Expr("inputs"), "col1"))),
+           Decl(Const("g", ty.f32(), MemberAccessor(Expr("inputs"), "col2"))),
+           Decl(Const("p", ty.vec4<f32>(),
+                      MemberAccessor(Expr("inputs"), "pos"))),
        },
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+       {Stage(ast::PipelineStage::kFragment)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
 
@@ -268,20 +261,19 @@
               {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)})});
 
   Func("foo", {Param("x", ty.f32())}, vertex_output_struct,
-       {create<ast::ReturnStatement>(Construct(
-           vertex_output_struct, Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
-                                           Expr("x"), Expr(1.f))))},
+       {Return(Construct(vertex_output_struct,
+                         Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
+                                   Expr("x"), Expr(1.f))))},
        {});
 
   Func("vert_main1", {}, vertex_output_struct,
-       {create<ast::ReturnStatement>(
-           Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+       {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
+       {Stage(ast::PipelineStage::kVertex)});
 
-  Func("vert_main2", {}, vertex_output_struct,
-       {create<ast::ReturnStatement>(
-           Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+  Func(
+      "vert_main2", {}, vertex_output_struct,
+      {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
+      {Stage(ast::PipelineStage::kVertex)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
 
@@ -320,33 +312,48 @@
 
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoint_With_Uniform) {
-  Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
-         ast::DecorationList{
-             create<ast::BindingDecoration>(0),
-             create<ast::GroupDecoration>(1),
-         });
+  auto* ubo_ty = Structure("UBO", {Member("coord", ty.vec4<f32>())},
+                           {create<ast::StructBlockDecoration>()});
+  auto* ubo = Global(
+      "ubo", ubo_ty, ast::StorageClass::kUniform, nullptr,
+      {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
-  auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
-                  MemberAccessor("coord", "x"));
-
-  Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+  Func("sub_func",
+       {
+           Param("param", ty.f32()),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       ty.f32(),
+       {
+           Return(MemberAccessor(MemberAccessor(ubo, "coord"), "x")),
+       });
+
+  auto* var =
+      Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
+
+  Func("frag_main", {}, ty.void_(),
+       {
+           Decl(var),
+           Return(),
+       },
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
 
   ASSERT_TRUE(gen.Generate(out)) << gen.error();
-  EXPECT_EQ(result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
+  EXPECT_EQ(result(), R"(struct UBO {
   float4 coord;
 };
 
+ConstantBuffer<UBO> ubo : register(b0, space1);
+
+float sub_func(float param) {
+  return ubo.coord.x;
+}
+
 void frag_main() {
-  float v = coord.x;
+  float v = sub_func(1.0f);
   return;
 }
 
@@ -357,10 +364,11 @@
 
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoint_With_UniformStruct) {
-  auto* s = Structure("Uniforms", {Member("coord", ty.vec4<f32>())});
+  auto* s = Structure("Uniforms", {Member("coord", ty.vec4<f32>())},
+                      {create<ast::StructBlockDecoration>()});
 
   Global("uniforms", s, ast::StorageClass::kUniform, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
@@ -370,12 +378,12 @@
                       MemberAccessor("uniforms", "coord"), Expr("x")));
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -409,7 +417,7 @@
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
@@ -418,12 +426,12 @@
                   MemberAccessor("coord", "b"));
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -455,7 +463,7 @@
   sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
@@ -464,12 +472,12 @@
                   MemberAccessor("coord", "b"));
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -501,19 +509,19 @@
   sem::AccessControl ac(ast::AccessControl::kWriteOnly, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(MemberAccessor("coord", "b"),
                                             Expr(2.0f)),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -545,19 +553,19 @@
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(MemberAccessor("coord", "b"),
                                             Expr(2.0f)),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -582,36 +590,35 @@
     HlslGeneratorImplTest_Function,
     Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) {  // NOLINT
   Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::LocationDecoration>(0),
          });
 
   Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::LocationDecoration>(1),
          });
 
   Global("val", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::LocationDecoration>(0),
          });
 
   Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
            create<ast::AssignmentStatement>(Expr("val"), Expr("param")),
-           create<ast::ReturnStatement>(Expr("foo")),
-       },
-       ast::DecorationList{});
+           Return(Expr("foo")),
+       });
 
   Func(
       "ep_1", ast::VariableList{}, ty.void_(),
-      ast::StatementList{
+      {
           create<ast::AssignmentStatement>(Expr("bar"), Call("sub_func", 1.0f)),
-          create<ast::ReturnStatement>(),
+          Return(),
       },
-      ast::DecorationList{
-          create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+      {
+          Stage(ast::PipelineStage::kFragment),
       });
 
   GeneratorImpl& gen = Build();
@@ -646,24 +653,23 @@
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) {
   Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
          });
 
   Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(Expr("param")),
-       },
-       ast::DecorationList{});
+       {
+           Return(Expr("param")),
+       });
 
   Func("ep_1", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(Expr("depth"),
                                             Call("sub_func", 1.0f)),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -693,31 +699,30 @@
     HlslGeneratorImplTest_Function,
     Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {  // NOLINT
   Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BuiltinDecoration>(ast::Builtin::kPosition),
          });
 
   Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
          });
 
   Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(Expr("depth"),
                                             MemberAccessor("coord", "x")),
-           create<ast::ReturnStatement>(Expr("param")),
-       },
-       ast::DecorationList{});
+           Return(Expr("param")),
+       });
 
   Func("ep_1", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::AssignmentStatement>(Expr("depth"),
                                             Call("sub_func", 1.0f)),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -749,37 +754,40 @@
 
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
-  Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
-         ast::DecorationList{
+  auto* s = Structure("S", {Member("x", ty.f32())},
+                      {create<ast::StructBlockDecoration>()});
+  Global("coord", s, ast::StorageClass::kUniform, nullptr,
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
 
   Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
-       },
-       ast::DecorationList{});
+       {
+           Return(MemberAccessor("coord", "x")),
+       });
 
   auto* var =
       Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
 
   ASSERT_TRUE(gen.Generate(out)) << gen.error();
-  EXPECT_EQ(result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
-  float4 coord;
+  EXPECT_EQ(result(), R"(struct S {
+  float x;
 };
 
+ConstantBuffer<S> coord : register(b0, space1);
+
 float sub_func(float param) {
   return coord.x;
 }
@@ -800,27 +808,26 @@
                       {create<ast::StructBlockDecoration>()});
   auto* ac = ty.access(ast::AccessControl::kReadWrite, s);
   Global("coord", ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(1),
          });
 
   Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
-       },
-       ast::DecorationList{});
+       {
+           Return(MemberAccessor("coord", "x")),
+       });
 
   auto* var =
       Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
 
   Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
+       {
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -847,25 +854,21 @@
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) {
   Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
+         {
              create<ast::LocationDecoration>(1),
          });
 
-  auto* list = create<ast::BlockStatement>(ast::StatementList{
-      create<ast::ReturnStatement>(),
-  });
-
   Func(
       "ep_1", ast::VariableList{}, ty.void_(),
-      ast::StatementList{
+      {
           create<ast::AssignmentStatement>(Expr("bar"), Expr(1.0f)),
           create<ast::IfStatement>(create<ast::BinaryExpression>(
                                        ast::BinaryOp::kEqual, Expr(1), Expr(1)),
-                                   list, ast::ElseStatementList{}),
-          create<ast::ReturnStatement>(),
+                                   Block(Return()), ast::ElseStatementList{}),
+          Return(),
       },
-      ast::DecorationList{
-          create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+      {
+          Stage(ast::PipelineStage::kFragment),
       });
 
   GeneratorImpl& gen = Build();
@@ -891,9 +894,9 @@
 
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoint_WithNameCollision) {
-  Func("GeometryShader", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+  Func("GeometryShader", ast::VariableList{}, ty.void_(), {},
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -908,11 +911,11 @@
 
 TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_Compute) {
   Func("main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
+       {
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+       {
+           Stage(ast::PipelineStage::kCompute),
        });
 
   GeneratorImpl& gen = Build();
@@ -931,11 +934,11 @@
 TEST_F(HlslGeneratorImplTest_Function,
        Emit_Decoration_EntryPoint_Compute_WithWorkgroup) {
   Func("main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
+       {
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+       {
+           Stage(ast::PipelineStage::kCompute),
            create<ast::WorkgroupDecoration>(2u, 4u, 6u),
        });
 
@@ -954,8 +957,8 @@
 
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
   Func("my_func", ast::VariableList{Param("a", ty.array<f32, 5>())}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
+       {
+           Return(),
        });
 
   GeneratorImpl& gen = Build();
@@ -990,14 +993,13 @@
   //   return;
   // }
 
-  auto* s =
-      Structure("Data", {Member("d", ty.f32())},
-                ast::DecorationList{create<ast::StructBlockDecoration>()});
+  auto* s = Structure("Data", {Member("d", ty.f32())},
+                      {create<ast::StructBlockDecoration>()});
 
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("data", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{
+         {
              create<ast::BindingDecoration>(0),
              create<ast::GroupDecoration>(0),
          });
@@ -1007,12 +1009,12 @@
                     MemberAccessor("data", "d"));
 
     Func("a", ast::VariableList{}, ty.void_(),
-         ast::StatementList{
+         {
              create<ast::VariableDeclStatement>(var),
-             create<ast::ReturnStatement>(),
+             Return(),
          },
-         ast::DecorationList{
-             create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+         {
+             Stage(ast::PipelineStage::kCompute),
          });
   }
 
@@ -1021,12 +1023,12 @@
                     MemberAccessor("data", "d"));
 
     Func("b", ast::VariableList{}, ty.void_(),
-         ast::StatementList{
+         {
              create<ast::VariableDeclStatement>(var),
-             create<ast::ReturnStatement>(),
+             Return(),
          },
-         ast::DecorationList{
-             create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+         {
+             Stage(ast::PipelineStage::kCompute),
          });
   }
 
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 473d321..e5c70a5 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -28,9 +28,9 @@
 TEST_F(MslGeneratorImplTest, Emit_Function) {
   Func("my_func", ast::VariableList{}, ty.void_(),
        ast::StatementList{
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{});
+       {});
 
   GeneratorImpl& gen = Build();
 
@@ -54,9 +54,9 @@
 
   Func("my_func", params, ty.void_(),
        ast::StatementList{
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{});
+       {});
 
   GeneratorImpl& gen = Build();
 
@@ -76,8 +76,7 @@
 TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_NoReturn_Void) {
   Func("main", ast::VariableList{}, ty.void_(),
        ast::StatementList{/* no explicit return */},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+       {Stage(ast::PipelineStage::kFragment)});
 
   GeneratorImpl& gen = Build();
 
@@ -97,9 +96,8 @@
   //   return foo;
   // }
   auto* foo_in = Param("foo", ty.f32(), {create<ast::LocationDecoration>(0)});
-  Func("frag_main", ast::VariableList{foo_in}, ty.f32(),
-       {create<ast::ReturnStatement>(Expr("foo"))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+  Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {Return(Expr("foo"))},
+       {Stage(ast::PipelineStage::kFragment)},
        {create<ast::LocationDecoration>(1)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -133,8 +131,8 @@
       Param("coord", ty.vec4<f32>(),
             {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
   Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
-       {create<ast::ReturnStatement>(MemberAccessor("coord", "x"))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+       {Return(MemberAccessor("coord", "x"))},
+       {Stage(ast::PipelineStage::kFragment)},
        {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
@@ -194,7 +192,7 @@
            WrapInStatement(
                Const("g", ty.f32(), MemberAccessor(Expr("colors"), "col2"))),
        },
-       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+       {Stage(ast::PipelineStage::kFragment)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
 
@@ -255,20 +253,19 @@
               {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)})});
 
   Func("foo", {Param("x", ty.f32())}, vertex_output_struct,
-       {create<ast::ReturnStatement>(Construct(
-           vertex_output_struct, Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
-                                           Expr("x"), Expr(1.f))))},
+       {Return(Construct(vertex_output_struct,
+                         Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
+                                   Expr("x"), Expr(1.f))))},
        {});
 
   Func("vert_main1", {}, vertex_output_struct,
-       {create<ast::ReturnStatement>(
-           Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+       {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
+       {Stage(ast::PipelineStage::kVertex)});
 
-  Func("vert_main2", {}, vertex_output_struct,
-       {create<ast::ReturnStatement>(
-           Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
-       {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+  Func(
+      "vert_main2", {}, vertex_output_struct,
+      {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
+      {Stage(ast::PipelineStage::kVertex)});
 
   GeneratorImpl& gen = SanitizeAndBuild();
 
@@ -317,8 +314,7 @@
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(1)});
+         {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
   auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
                   MemberAccessor("coord", "b"));
@@ -326,10 +322,10 @@
   Func("frag_main", ast::VariableList{}, ty.void_(),
        ast::StatementList{
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -363,8 +359,7 @@
   sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(1)});
+         {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
   auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
                   MemberAccessor("coord", "b"));
@@ -372,10 +367,10 @@
   Func("frag_main", ast::VariableList{}, ty.void_(),
        ast::StatementList{
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -402,13 +397,13 @@
     MslGeneratorImplTest,
     Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) {  // NOLINT
   Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
-         ast::DecorationList{create<ast::LocationDecoration>(0)});
+         {create<ast::LocationDecoration>(0)});
 
   Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{create<ast::LocationDecoration>(1)});
+         {create<ast::LocationDecoration>(1)});
 
   Global("val", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{create<ast::LocationDecoration>(0)});
+         {create<ast::LocationDecoration>(0)});
 
   ast::VariableList params;
   params.push_back(Param("param", ty.f32()));
@@ -416,18 +411,18 @@
   auto body = ast::StatementList{
       create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
       create<ast::AssignmentStatement>(Expr("val"), Expr("param")),
-      create<ast::ReturnStatement>(Expr("foo"))};
+      Return(Expr("foo"))};
 
-  Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+  Func("sub_func", params, ty.f32(), body, {});
 
   body = ast::StatementList{
       create<ast::AssignmentStatement>(Expr("bar"), Call("sub_func", 1.0f)),
-      create<ast::ReturnStatement>(),
+      Return(),
   };
 
   Func("ep_1", ast::VariableList{}, ty.void_(), body,
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -464,26 +459,25 @@
 TEST_F(MslGeneratorImplTest,
        Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) {
   Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
-             create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
+         {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
 
   ast::VariableList params;
   params.push_back(Param("param", ty.f32()));
 
   Func("sub_func", params, ty.f32(),
        ast::StatementList{
-           create<ast::ReturnStatement>(Expr("param")),
+           Return(Expr("param")),
        },
-       ast::DecorationList{});
+       {});
 
   auto body = ast::StatementList{
       create<ast::AssignmentStatement>(Expr("depth"), Call("sub_func", 1.0f)),
-      create<ast::ReturnStatement>(),
+      Return(),
   };
 
   Func("ep_1", ast::VariableList{}, ty.void_(), body,
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -514,12 +508,10 @@
     MslGeneratorImplTest,
     Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) {  // NOLINT
   Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
-         ast::DecorationList{
-             create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
+         {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
 
   Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{
-             create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
+         {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
 
   ast::VariableList params;
   params.push_back(Param("param", ty.f32()));
@@ -527,19 +519,19 @@
   auto body = ast::StatementList{
       create<ast::AssignmentStatement>(Expr("depth"),
                                        MemberAccessor("coord", "x")),
-      create<ast::ReturnStatement>(Expr("param")),
+      Return(Expr("param")),
   };
 
-  Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+  Func("sub_func", params, ty.f32(), body, {});
 
   body = ast::StatementList{
       create<ast::AssignmentStatement>(Expr("depth"), Call("sub_func", 1.0f)),
-      create<ast::ReturnStatement>(),
+      Return(),
   };
 
   Func("ep_1", ast::VariableList{}, ty.void_(), body,
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -568,29 +560,31 @@
 
 TEST_F(MslGeneratorImplTest,
        Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
-  Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(1)});
+  auto* ubo_ty = Structure("UBO", {Member("coord", ty.vec4<f32>())},
+                           {create<ast::StructBlockDecoration>()});
+  auto* ubo = Global(
+      "ubo", ubo_ty, ast::StorageClass::kUniform, nullptr,
+      {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
-  ast::VariableList params;
-  params.push_back(Param("param", ty.f32()));
-
-  auto body = ast::StatementList{
-      create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
-  };
-
-  Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+  Func("sub_func",
+       {
+           Param("param", ty.f32()),
+       },
+       ty.f32(),
+       {
+           Return(MemberAccessor(MemberAccessor(ubo, "coord"), "x")),
+       });
 
   auto* var =
       Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
 
-  Func("frag_main", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+  Func("frag_main", {}, ty.void_(),
+       {
+           Decl(var),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -599,12 +593,16 @@
   EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
 
 using namespace metal;
-float sub_func(constant float4& coord, float param) {
-  return coord.x;
+struct UBO {
+  /* 0x0000 */ packed_float4 coord;
+};
+
+float sub_func(constant UBO& ubo, float param) {
+  return ubo.coord.x;
 }
 
-fragment void frag_main(constant float4& coord [[buffer(0)]]) {
-  float v = sub_func(coord, 1.0f);
+fragment void frag_main(constant UBO& ubo [[buffer(0)]]) {
+  float v = sub_func(ubo, 1.0f);
   return;
 }
 
@@ -623,16 +621,14 @@
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(1)});
+         {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
   ast::VariableList params;
   params.push_back(Param("param", ty.f32()));
 
-  auto body = ast::StatementList{
-      create<ast::ReturnStatement>(MemberAccessor("coord", "b"))};
+  auto body = ast::StatementList{Return(MemberAccessor("coord", "b"))};
 
-  Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+  Func("sub_func", params, ty.f32(), body, {});
 
   auto* var =
       Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
@@ -640,10 +636,10 @@
   Func("frag_main", ast::VariableList{}, ty.void_(),
        ast::StatementList{
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -681,16 +677,14 @@
   sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
 
   Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(1)});
+         {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
 
   ast::VariableList params;
   params.push_back(Param("param", ty.f32()));
 
-  auto body = ast::StatementList{
-      create<ast::ReturnStatement>(MemberAccessor("coord", "b"))};
+  auto body = ast::StatementList{Return(MemberAccessor("coord", "b"))};
 
-  Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+  Func("sub_func", params, ty.f32(), body, {});
 
   auto* var =
       Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
@@ -698,10 +692,10 @@
   Func("frag_main", ast::VariableList{}, ty.void_(),
        ast::StatementList{
            create<ast::VariableDeclStatement>(var),
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -731,10 +725,10 @@
 TEST_F(MslGeneratorImplTest,
        Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) {
   Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
-         ast::DecorationList{create<ast::LocationDecoration>(1)});
+         {create<ast::LocationDecoration>(1)});
 
   auto* list = create<ast::BlockStatement>(ast::StatementList{
-      create<ast::ReturnStatement>(),
+      Return(),
   });
 
   auto body = ast::StatementList{
@@ -742,12 +736,12 @@
       create<ast::IfStatement>(create<ast::BinaryExpression>(
                                    ast::BinaryOp::kEqual, Expr(1), Expr(1)),
                                list, ast::ElseStatementList{}),
-      create<ast::ReturnStatement>(),
+      Return(),
   };
 
   Func("ep_1", ast::VariableList{}, ty.void_(), body,
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       {
+           Stage(ast::PipelineStage::kFragment),
        });
 
   GeneratorImpl& gen = Build();
@@ -778,9 +772,9 @@
 
   Func("my_func", params, ty.void_(),
        ast::StatementList{
-           create<ast::ReturnStatement>(),
+           Return(),
        },
-       ast::DecorationList{});
+       {});
 
   GeneratorImpl& gen = Build();
 
@@ -821,8 +815,7 @@
   sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
 
   Global("data", &ac, ast::StorageClass::kStorage, nullptr,
-         ast::DecorationList{create<ast::BindingDecoration>(0),
-                             create<ast::GroupDecoration>(0)});
+         {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(0)});
 
   {
     auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
@@ -831,10 +824,10 @@
     Func("a", ast::VariableList{}, ty.void_(),
          ast::StatementList{
              create<ast::VariableDeclStatement>(var),
-             create<ast::ReturnStatement>(),
+             Return(),
          },
-         ast::DecorationList{
-             create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+         {
+             Stage(ast::PipelineStage::kCompute),
          });
   }
 
@@ -843,10 +836,8 @@
                     MemberAccessor("data", "d"));
 
     Func("b", ast::VariableList{}, ty.void_(),
-         ast::StatementList{create<ast::VariableDeclStatement>(var),
-                            create<ast::ReturnStatement>()},
-         ast::DecorationList{
-             create<ast::StageDecoration>(ast::PipelineStage::kCompute)});
+         ast::StatementList{create<ast::VariableDeclStatement>(var), Return()},
+         {Stage(ast::PipelineStage::kCompute)});
   }
 
   GeneratorImpl& gen = Build();
diff --git a/test/compute_boids.wgsl b/test/compute_boids.wgsl
index 2f168dc..7570102 100644
--- a/test/compute_boids.wgsl
+++ b/test/compute_boids.wgsl
@@ -53,7 +53,7 @@
   particles : array<Particle, 5>;
 };
 
-[[binding(0), group(0)]] var<uniform> params : [[access(read)]] SimParams;
+[[binding(0), group(0)]] var<uniform> params : SimParams;
 [[binding(1), group(0)]] var<storage> particlesA : [[access(read_write)]] Particles;
 [[binding(2), group(0)]] var<storage> particlesB : [[access(read_write)]] Particles;
 
diff --git a/test/cube.wgsl b/test/cube.wgsl
index 313954a..6eefd59 100644
--- a/test/cube.wgsl
+++ b/test/cube.wgsl
@@ -17,7 +17,7 @@
   modelViewProjectionMatrix : mat4x4<f32>;
 };
 
-[[binding(0), group(0)]] var<uniform> uniforms : [[access(read)]] Uniforms;
+[[binding(0), group(0)]] var<uniform> uniforms : Uniforms;
 
 struct VertexInput {
   [[location(0)]] cur_position : vec4<f32>;