resolver: Validate uniform buffer types
Fixed: tint:210
Change-Id: I7763ca23a5dce09755a1ca71d35f24897a875ac0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48604
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/ast/module_clone_test.cc b/src/ast/module_clone_test.cc
index 6f39c42..6b61483 100644
--- a/src/ast/module_clone_test.cc
+++ b/src/ast/module_clone_test.cc
@@ -39,7 +39,7 @@
type t0 = [[stride(16)]] array<vec4<f32>>;
type t1 = array<vec4<f32>>;
-var<uniform> g0 : u32 = 20u;
+var<private> g0 : u32 = 20u;
var<private> g1 : f32 = 123.0;
var g2 : texture_2d<f32>;
var g3 : [[access(read)]] texture_storage_2d<r32uint>;
@@ -47,7 +47,7 @@
var g5 : [[access(read)]] texture_storage_2d<r32uint>;
var g6 : [[access(write)]] texture_storage_2d<rg32float>;
-[[builtin(position)]] var<uniform> g7 : vec3<f32>;
+var<private> g7 : vec3<f32>;
[[group(10), binding(20)]] var<storage> g8 : [[access(write)]] S;
[[group(10), binding(20)]] var<storage> g9 : [[access(read)]] S;
[[group(10), binding(20)]] var<storage> g10 : [[access(read_write)]] S;
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index a14baa1..eeb0ee7 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -1822,31 +1822,6 @@
EXPECT_TRUE(error.find("not an entry point") != std::string::npos);
}
-TEST_F(InspectorGetUniformBufferResourceBindingsTest, MissingBlockDeco) {
- ast::DecorationList decos;
- auto* str = create<ast::Struct>(
- Sym("foo_type"),
- ast::StructMemberList{Member(StructMemberName(0, ty.i32()), ty.i32())},
- decos);
-
- auto* foo_type = ty.struct_(str);
- AddUniformBuffer("foo_ub", foo_type, 0, 0);
-
- MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}});
-
- MakeCallerBodyFunction(
- "ep_func", {"ub_func"},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
- });
-
- Inspector& inspector = Build();
-
- auto result = inspector.GetUniformBufferResourceBindings("ep_func");
- ASSERT_FALSE(inspector.has_error()) << inspector.error();
- EXPECT_EQ(0u, result.size());
-}
-
TEST_F(InspectorGetUniformBufferResourceBindingsTest, Simple) {
sem::StructType* foo_struct_type =
MakeUniformBufferType("foo_type", {ty.i32()});
diff --git a/src/program_builder.h b/src/program_builder.h
index 2c61450..c006792 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -985,13 +985,42 @@
decorations);
}
- /// @param args the arguments to pass to Var()
- /// @returns a `ast::Variable` constructed by calling Var() with the arguments
- /// of `args`, which is automatically registered as a global variable with the
- /// ast::Module.
- template <typename... ARGS>
- ast::Variable* Global(ARGS&&... args) {
- auto* var = Var(std::forward<ARGS>(args)...);
+ /// @param name the variable name
+ /// @param type the variable type
+ /// @param storage the variable storage class
+ /// @param constructor constructor expression
+ /// @param decorations variable decorations
+ /// @returns a new `ast::Variable`, which is automatically registered as a
+ /// global variable with the ast::Module.
+ template <typename NAME>
+ ast::Variable* Global(NAME&& name,
+ sem::Type* type,
+ ast::StorageClass storage,
+ ast::Expression* constructor = nullptr,
+ ast::DecorationList decorations = {}) {
+ auto* var =
+ Var(std::forward<NAME>(name), type, storage, constructor, decorations);
+ AST().AddGlobalVariable(var);
+ return var;
+ }
+
+ /// @param source the variable source
+ /// @param name the variable name
+ /// @param type the variable type
+ /// @param storage the variable storage class
+ /// @param constructor constructor expression
+ /// @param decorations variable decorations
+ /// @returns a new `ast::Variable`, which is automatically registered as a
+ /// global variable with the ast::Module.
+ template <typename NAME>
+ ast::Variable* Global(const Source& source,
+ NAME&& name,
+ sem::Type* type,
+ ast::StorageClass storage,
+ ast::Expression* constructor = nullptr,
+ ast::DecorationList decorations = {}) {
+ auto* var = Var(source, std::forward<NAME>(name), type, storage,
+ constructor, decorations);
AST().AddGlobalVariable(var);
return var;
}
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 084fe7c..6532144 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -432,37 +432,69 @@
}
bool Resolver::ValidateGlobalVariable(const VariableInfo* info) {
- if (info->storage_class == ast::StorageClass::kStorage) {
- // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
- // Variables in the storage storage class and variables with a storage
- // texture type must have an access attribute applied to the store type.
+ switch (info->storage_class) {
+ case ast::StorageClass::kStorage: {
+ // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
+ // Variables in the storage storage class and variables with a storage
+ // texture type must have an access attribute applied to the store type.
- // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
- // A variable in the storage storage class is a storage buffer variable. Its
- // store type must be a host-shareable structure type with block attribute,
- // satisfying the storage class constraints.
+ // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+ // A variable in the storage storage class is a storage buffer variable.
+ // Its store type must be a host-shareable structure type with block
+ // attribute, satisfying the storage class constraints.
- auto* access = info->type->As<sem::AccessControl>();
- auto* str = access ? access->type()->As<sem::StructType>() : nullptr;
- if (!str) {
- diagnostics_.add_error(
- "variables declared in the <storage> storage class must be of an "
- "[[access]] qualified structure type",
- info->declaration->source());
- return false;
- }
-
- if (!str->IsBlockDecorated()) {
- diagnostics_.add_error(
- "structure used as a storage buffer must be declared with the "
- "[[block]] decoration",
- str->impl()->source());
- if (info->declaration->source().range.begin.line) {
- diagnostics_.add_note("structure used as storage buffer here",
- info->declaration->source());
+ auto* access = info->type->As<sem::AccessControl>();
+ auto* str = access ? access->type()->As<sem::StructType>() : nullptr;
+ if (!str) {
+ diagnostics_.add_error(
+ "variables declared in the <storage> storage class must be of an "
+ "[[access]] qualified structure type",
+ info->declaration->source());
+ return false;
}
- return false;
+
+ if (!str->IsBlockDecorated()) {
+ diagnostics_.add_error(
+ "structure used as a storage buffer must be declared with the "
+ "[[block]] decoration",
+ str->impl()->source());
+ if (info->declaration->source().range.begin.line) {
+ diagnostics_.add_note("structure used as storage buffer here",
+ info->declaration->source());
+ }
+ return false;
+ }
+ break;
}
+ case ast::StorageClass::kUniform: {
+ // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
+ // A variable in the uniform storage class is a uniform buffer variable.
+ // Its store type must be a host-shareable structure type with block
+ // attribute, satisfying the storage class constraints.
+ auto* str = info->type->As<sem::StructType>();
+ if (!str) {
+ diagnostics_.add_error(
+ "variables declared in the <uniform> storage class must be of a "
+ "structure type",
+ info->declaration->source());
+ return false;
+ }
+
+ if (!str->IsBlockDecorated()) {
+ diagnostics_.add_error(
+ "structure used as a uniform buffer must be declared with the "
+ "[[block]] decoration",
+ str->impl()->source());
+ if (info->declaration->source().range.begin.line) {
+ diagnostics_.add_note("structure used as uniform buffer here",
+ info->declaration->source());
+ }
+ return false;
+ }
+ break;
+ }
+ default:
+ break;
}
return ValidateVariable(info->declaration);
diff --git a/src/resolver/storage_class_validation_test.cc b/src/resolver/storage_class_validation_test.cc
index 931b23a..1ce3cbf 100644
--- a/src/resolver/storage_class_validation_test.cc
+++ b/src/resolver/storage_class_validation_test.cc
@@ -138,6 +138,94 @@
ASSERT_TRUE(r()->Resolve());
}
+///
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) {
+ // var<uniform> g : bool;
+ Global(Source{{56, 78}}, "g", ty.bool_(), ast::StorageClass::kUniform);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
+ // var<uniform> g : ptr<i32, input>;
+ Global(Source{{56, 78}}, "g", ty.pointer<i32>(ast::StorageClass::kInput),
+ ast::StorageClass::kUniform);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
+ // var<uniform> g : [[access(read)]] array<S, 3>;
+ auto* s = Structure("S", {Member("a", ty.f32())});
+ auto* a = ty.array(s, 3);
+ auto* ac = ty.access(ast::AccessControl::kReadOnly, a);
+ Global(Source{{56, 78}}, "g", ac, ast::StorageClass::kUniform);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) {
+ // type a = bool;
+ // var<uniform> g : [[access(read)]] a;
+ auto* a = ty.alias("a", ty.bool_());
+ Global(Source{{56, 78}}, "g", a, ast::StorageClass::kUniform);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoBlockDecoration) {
+ // struct S { x : i32 };
+ // var<uniform> g : S;
+ auto* s = Structure(Source{{12, 34}}, "S", {Member("x", ty.i32())});
+ Global(Source{{56, 78}}, "g", s, ast::StorageClass::kUniform);
+
+ ASSERT_FALSE(r()->Resolve());
+
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: structure used as a uniform buffer must be declared with the [[block]] decoration
+56:78 note: structure used as uniform buffer here)");
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Basic) {
+ // [[block]] struct S { x : i32 };
+ // var<uniform> g : S;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockDecoration>()});
+ Global(Source{{56, 78}}, "g", s, ast::StorageClass::kUniform);
+
+ ASSERT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverStorageClassValidationTest, UniformBufferNoError_Aliases) {
+ // [[block]] struct S { x : i32 };
+ // type a1 = S;
+ // var<uniform> g : a1;
+ auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
+ {create<ast::StructBlockDecoration>()});
+ auto* a1 = ty.alias("a1", s);
+ Global(Source{{56, 78}}, "g", a1, ast::StorageClass::kUniform);
+
+ ASSERT_TRUE(r()->Resolve());
+}
+
} // namespace
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/struct_storage_class_use_test.cc b/src/resolver/struct_storage_class_use_test.cc
index 4751992..a34a835 100644
--- a/src/resolver/struct_storage_class_use_test.cc
+++ b/src/resolver/struct_storage_class_use_test.cc
@@ -66,53 +66,53 @@
TEST_F(ResolverStorageClassUseTest, StructReachableFromGlobal) {
auto* s = Structure("S", {Member("a", ty.f32())});
- Global("g", s, ast::StorageClass::kUniform);
+ Global("g", s, ast::StorageClass::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
- UnorderedElementsAre(ast::StorageClass::kUniform));
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalAlias) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.alias("A", s);
- Global("g", a, ast::StorageClass::kUniform);
+ Global("g", a, ast::StorageClass::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
- UnorderedElementsAre(ast::StorageClass::kUniform));
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* o = Structure("O", {Member("a", s)});
- Global("g", o, ast::StorageClass::kUniform);
+ Global("g", o, ast::StorageClass::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
- UnorderedElementsAre(ast::StorageClass::kUniform));
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
}
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto* a = ty.array(s, 3);
- Global("g", a, ast::StorageClass::kUniform);
+ Global("g", a, ast::StorageClass::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(s);
ASSERT_NE(sem, nullptr);
EXPECT_THAT(sem->StorageClassUsage(),
- UnorderedElementsAre(ast::StorageClass::kUniform));
+ UnorderedElementsAre(ast::StorageClass::kPrivate));
}
TEST_F(ResolverStorageClassUseTest, StructReachableFromLocal) {
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index e3f733b..0c3cd2a 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -31,10 +31,9 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
Func("my_func", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{});
+ {
+ Return(),
+ });
GeneratorImpl& gen = Build();
@@ -50,10 +49,9 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
Func("GeometryShader", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{});
+ {
+ Return(),
+ });
GeneratorImpl& gen = SanitizeAndBuild();
@@ -68,10 +66,9 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
Func("my_func", ast::VariableList{Param("a", ty.f32()), Param("b", ty.i32())},
ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
- },
- ast::DecorationList{});
+ {
+ Return(),
+ });
GeneratorImpl& gen = Build();
@@ -87,10 +84,9 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_NoReturn_Void) {
- Func("main", ast::VariableList{}, ty.void_(),
- ast::StatementList{/* no explicit return */},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ Func("main", ast::VariableList{}, ty.void_(), {/* no explicit return */},
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -111,9 +107,8 @@
// return foo;
// }
auto* foo_in = Param("foo", ty.f32(), {create<ast::LocationDecoration>(0)});
- Func("frag_main", ast::VariableList{foo_in}, ty.f32(),
- {create<ast::ReturnStatement>(Expr("foo"))},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {Return(Expr("foo"))},
+ {Stage(ast::PipelineStage::kFragment)},
{create<ast::LocationDecoration>(1)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -146,8 +141,8 @@
Param("coord", ty.vec4<f32>(),
{create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
- {create<ast::ReturnStatement>(MemberAccessor("coord", "x"))},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ {Return(MemberAccessor("coord", "x"))},
+ {Stage(ast::PipelineStage::kFragment)},
{create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -201,14 +196,12 @@
Func("frag_main", {Param("inputs", interface_struct)}, ty.void_(),
{
- WrapInStatement(
- Const("r", ty.f32(), MemberAccessor(Expr("inputs"), "col1"))),
- WrapInStatement(
- Const("g", ty.f32(), MemberAccessor(Expr("inputs"), "col2"))),
- WrapInStatement(Const("p", ty.vec4<f32>(),
- MemberAccessor(Expr("inputs"), "pos"))),
+ Decl(Const("r", ty.f32(), MemberAccessor(Expr("inputs"), "col1"))),
+ Decl(Const("g", ty.f32(), MemberAccessor(Expr("inputs"), "col2"))),
+ Decl(Const("p", ty.vec4<f32>(),
+ MemberAccessor(Expr("inputs"), "pos"))),
},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+ {Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -268,20 +261,19 @@
{create<ast::BuiltinDecoration>(ast::Builtin::kPosition)})});
Func("foo", {Param("x", ty.f32())}, vertex_output_struct,
- {create<ast::ReturnStatement>(Construct(
- vertex_output_struct, Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
- Expr("x"), Expr(1.f))))},
+ {Return(Construct(vertex_output_struct,
+ Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
+ Expr("x"), Expr(1.f))))},
{});
Func("vert_main1", {}, vertex_output_struct,
- {create<ast::ReturnStatement>(
- Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
- {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+ {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
+ {Stage(ast::PipelineStage::kVertex)});
- Func("vert_main2", {}, vertex_output_struct,
- {create<ast::ReturnStatement>(
- Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
- {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+ Func(
+ "vert_main2", {}, vertex_output_struct,
+ {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
+ {Stage(ast::PipelineStage::kVertex)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -320,33 +312,48 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_With_Uniform) {
- Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
- ast::DecorationList{
- create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1),
- });
+ auto* ubo_ty = Structure("UBO", {Member("coord", ty.vec4<f32>())},
+ {create<ast::StructBlockDecoration>()});
+ auto* ubo = Global(
+ "ubo", ubo_ty, ast::StorageClass::kUniform, nullptr,
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
- auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
- MemberAccessor("coord", "x"));
-
- Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Func("sub_func",
+ {
+ Param("param", ty.f32()),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ ty.f32(),
+ {
+ Return(MemberAccessor(MemberAccessor(ubo, "coord"), "x")),
+ });
+
+ auto* var =
+ Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
+
+ Func("frag_main", {}, ty.void_(),
+ {
+ Decl(var),
+ Return(),
+ },
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
+ EXPECT_EQ(result(), R"(struct UBO {
float4 coord;
};
+ConstantBuffer<UBO> ubo : register(b0, space1);
+
+float sub_func(float param) {
+ return ubo.coord.x;
+}
+
void frag_main() {
- float v = coord.x;
+ float v = sub_func(1.0f);
return;
}
@@ -357,10 +364,11 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_With_UniformStruct) {
- auto* s = Structure("Uniforms", {Member("coord", ty.vec4<f32>())});
+ auto* s = Structure("Uniforms", {Member("coord", ty.vec4<f32>())},
+ {create<ast::StructBlockDecoration>()});
Global("uniforms", s, ast::StorageClass::kUniform, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
@@ -370,12 +378,12 @@
MemberAccessor("uniforms", "coord"), Expr("x")));
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -409,7 +417,7 @@
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
@@ -418,12 +426,12 @@
MemberAccessor("coord", "b"));
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -455,7 +463,7 @@
sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
@@ -464,12 +472,12 @@
MemberAccessor("coord", "b"));
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -501,19 +509,19 @@
sem::AccessControl ac(ast::AccessControl::kWriteOnly, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(MemberAccessor("coord", "b"),
Expr(2.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -545,19 +553,19 @@
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(MemberAccessor("coord", "b"),
Expr(2.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -582,36 +590,35 @@
HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT
Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
- ast::DecorationList{
+ {
create<ast::LocationDecoration>(0),
});
Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
+ {
create<ast::LocationDecoration>(1),
});
Global("val", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
+ {
create<ast::LocationDecoration>(0),
});
Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
create<ast::AssignmentStatement>(Expr("val"), Expr("param")),
- create<ast::ReturnStatement>(Expr("foo")),
- },
- ast::DecorationList{});
+ Return(Expr("foo")),
+ });
Func(
"ep_1", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("bar"), Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -646,24 +653,23 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) {
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
+ {
create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
});
Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
- ast::StatementList{
- create<ast::ReturnStatement>(Expr("param")),
- },
- ast::DecorationList{});
+ {
+ Return(Expr("param")),
+ });
Func("ep_1", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("depth"),
Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -693,31 +699,30 @@
HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT
Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
- ast::DecorationList{
+ {
create<ast::BuiltinDecoration>(ast::Builtin::kPosition),
});
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
+ {
create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth),
});
Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("depth"),
MemberAccessor("coord", "x")),
- create<ast::ReturnStatement>(Expr("param")),
- },
- ast::DecorationList{});
+ Return(Expr("param")),
+ });
Func("ep_1", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("depth"),
Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -749,37 +754,40 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
- Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
- ast::DecorationList{
+ auto* s = Structure("S", {Member("x", ty.f32())},
+ {create<ast::StructBlockDecoration>()});
+ Global("coord", s, ast::StorageClass::kUniform, nullptr,
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
- ast::StatementList{
- create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
- },
- ast::DecorationList{});
+ {
+ Return(MemberAccessor("coord", "x")),
+ });
auto* var =
Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
- float4 coord;
+ EXPECT_EQ(result(), R"(struct S {
+ float x;
};
+ConstantBuffer<S> coord : register(b0, space1);
+
float sub_func(float param) {
return coord.x;
}
@@ -800,27 +808,26 @@
{create<ast::StructBlockDecoration>()});
auto* ac = ty.access(ast::AccessControl::kReadWrite, s);
Global("coord", ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
- ast::StatementList{
- create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
- },
- ast::DecorationList{});
+ {
+ Return(MemberAccessor("coord", "x")),
+ });
auto* var =
Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -847,25 +854,21 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) {
Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
+ {
create<ast::LocationDecoration>(1),
});
- auto* list = create<ast::BlockStatement>(ast::StatementList{
- create<ast::ReturnStatement>(),
- });
-
Func(
"ep_1", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::AssignmentStatement>(Expr("bar"), Expr(1.0f)),
create<ast::IfStatement>(create<ast::BinaryExpression>(
ast::BinaryOp::kEqual, Expr(1), Expr(1)),
- list, ast::ElseStatementList{}),
- create<ast::ReturnStatement>(),
+ Block(Return()), ast::ElseStatementList{}),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -891,9 +894,9 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_WithNameCollision) {
- Func("GeometryShader", ast::VariableList{}, ty.void_(), ast::StatementList{},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ Func("GeometryShader", ast::VariableList{}, ty.void_(), {},
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -908,11 +911,11 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_Compute) {
Func("main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
+ {
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ {
+ Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
@@ -931,11 +934,11 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Decoration_EntryPoint_Compute_WithWorkgroup) {
Func("main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
+ {
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ {
+ Stage(ast::PipelineStage::kCompute),
create<ast::WorkgroupDecoration>(2u, 4u, 6u),
});
@@ -954,8 +957,8 @@
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func", ast::VariableList{Param("a", ty.array<f32, 5>())}, ty.void_(),
- ast::StatementList{
- create<ast::ReturnStatement>(),
+ {
+ Return(),
});
GeneratorImpl& gen = Build();
@@ -990,14 +993,13 @@
// return;
// }
- auto* s =
- Structure("Data", {Member("d", ty.f32())},
- ast::DecorationList{create<ast::StructBlockDecoration>()});
+ auto* s = Structure("Data", {Member("d", ty.f32())},
+ {create<ast::StructBlockDecoration>()});
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("data", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{
+ {
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
@@ -1007,12 +1009,12 @@
MemberAccessor("data", "d"));
Func("a", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ {
+ Stage(ast::PipelineStage::kCompute),
});
}
@@ -1021,12 +1023,12 @@
MemberAccessor("data", "d"));
Func("b", ast::VariableList{}, ty.void_(),
- ast::StatementList{
+ {
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ {
+ Stage(ast::PipelineStage::kCompute),
});
}
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index 473d321..e5c70a5 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -28,9 +28,9 @@
TEST_F(MslGeneratorImplTest, Emit_Function) {
Func("my_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{});
+ {});
GeneratorImpl& gen = Build();
@@ -54,9 +54,9 @@
Func("my_func", params, ty.void_(),
ast::StatementList{
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{});
+ {});
GeneratorImpl& gen = Build();
@@ -76,8 +76,7 @@
TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_NoReturn_Void) {
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{/* no explicit return */},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+ {Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();
@@ -97,9 +96,8 @@
// return foo;
// }
auto* foo_in = Param("foo", ty.f32(), {create<ast::LocationDecoration>(0)});
- Func("frag_main", ast::VariableList{foo_in}, ty.f32(),
- {create<ast::ReturnStatement>(Expr("foo"))},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ Func("frag_main", ast::VariableList{foo_in}, ty.f32(), {Return(Expr("foo"))},
+ {Stage(ast::PipelineStage::kFragment)},
{create<ast::LocationDecoration>(1)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -133,8 +131,8 @@
Param("coord", ty.vec4<f32>(),
{create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
- {create<ast::ReturnStatement>(MemberAccessor("coord", "x"))},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)},
+ {Return(MemberAccessor("coord", "x"))},
+ {Stage(ast::PipelineStage::kFragment)},
{create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -194,7 +192,7 @@
WrapInStatement(
Const("g", ty.f32(), MemberAccessor(Expr("colors"), "col2"))),
},
- {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
+ {Stage(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -255,20 +253,19 @@
{create<ast::BuiltinDecoration>(ast::Builtin::kPosition)})});
Func("foo", {Param("x", ty.f32())}, vertex_output_struct,
- {create<ast::ReturnStatement>(Construct(
- vertex_output_struct, Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
- Expr("x"), Expr(1.f))))},
+ {Return(Construct(vertex_output_struct,
+ Construct(ty.vec4<f32>(), Expr("x"), Expr("x"),
+ Expr("x"), Expr(1.f))))},
{});
Func("vert_main1", {}, vertex_output_struct,
- {create<ast::ReturnStatement>(
- Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
- {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+ {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))},
+ {Stage(ast::PipelineStage::kVertex)});
- Func("vert_main2", {}, vertex_output_struct,
- {create<ast::ReturnStatement>(
- Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
- {create<ast::StageDecoration>(ast::PipelineStage::kVertex)});
+ Func(
+ "vert_main2", {}, vertex_output_struct,
+ {Return(Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))},
+ {Stage(ast::PipelineStage::kVertex)});
GeneratorImpl& gen = SanitizeAndBuild();
@@ -317,8 +314,7 @@
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1)});
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
MemberAccessor("coord", "b"));
@@ -326,10 +322,10 @@
Func("frag_main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -363,8 +359,7 @@
sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1)});
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
MemberAccessor("coord", "b"));
@@ -372,10 +367,10 @@
Func("frag_main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -402,13 +397,13 @@
MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT
Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
- ast::DecorationList{create<ast::LocationDecoration>(0)});
+ {create<ast::LocationDecoration>(0)});
Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{create<ast::LocationDecoration>(1)});
+ {create<ast::LocationDecoration>(1)});
Global("val", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{create<ast::LocationDecoration>(0)});
+ {create<ast::LocationDecoration>(0)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
@@ -416,18 +411,18 @@
auto body = ast::StatementList{
create<ast::AssignmentStatement>(Expr("bar"), Expr("foo")),
create<ast::AssignmentStatement>(Expr("val"), Expr("param")),
- create<ast::ReturnStatement>(Expr("foo"))};
+ Return(Expr("foo"))};
- Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+ Func("sub_func", params, ty.f32(), body, {});
body = ast::StatementList{
create<ast::AssignmentStatement>(Expr("bar"), Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -464,26 +459,25 @@
TEST_F(MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) {
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
+ {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
Func("sub_func", params, ty.f32(),
ast::StatementList{
- create<ast::ReturnStatement>(Expr("param")),
+ Return(Expr("param")),
},
- ast::DecorationList{});
+ {});
auto body = ast::StatementList{
create<ast::AssignmentStatement>(Expr("depth"), Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -514,12 +508,10 @@
MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT
Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
+ {create<ast::BuiltinDecoration>(ast::Builtin::kPosition)});
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{
- create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
+ {create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
@@ -527,19 +519,19 @@
auto body = ast::StatementList{
create<ast::AssignmentStatement>(Expr("depth"),
MemberAccessor("coord", "x")),
- create<ast::ReturnStatement>(Expr("param")),
+ Return(Expr("param")),
};
- Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+ Func("sub_func", params, ty.f32(), body, {});
body = ast::StatementList{
create<ast::AssignmentStatement>(Expr("depth"), Call("sub_func", 1.0f)),
- create<ast::ReturnStatement>(),
+ Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -568,29 +560,31 @@
TEST_F(MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
- Global("coord", ty.vec4<f32>(), ast::StorageClass::kUniform, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1)});
+ auto* ubo_ty = Structure("UBO", {Member("coord", ty.vec4<f32>())},
+ {create<ast::StructBlockDecoration>()});
+ auto* ubo = Global(
+ "ubo", ubo_ty, ast::StorageClass::kUniform, nullptr,
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
- ast::VariableList params;
- params.push_back(Param("param", ty.f32()));
-
- auto body = ast::StatementList{
- create<ast::ReturnStatement>(MemberAccessor("coord", "x")),
- };
-
- Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+ Func("sub_func",
+ {
+ Param("param", ty.f32()),
+ },
+ ty.f32(),
+ {
+ Return(MemberAccessor(MemberAccessor(ubo, "coord"), "x")),
+ });
auto* var =
Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
- Func("frag_main", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Func("frag_main", {}, ty.void_(),
+ {
+ Decl(var),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -599,12 +593,16 @@
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
-float sub_func(constant float4& coord, float param) {
- return coord.x;
+struct UBO {
+ /* 0x0000 */ packed_float4 coord;
+};
+
+float sub_func(constant UBO& ubo, float param) {
+ return ubo.coord.x;
}
-fragment void frag_main(constant float4& coord [[buffer(0)]]) {
- float v = sub_func(coord, 1.0f);
+fragment void frag_main(constant UBO& ubo [[buffer(0)]]) {
+ float v = sub_func(ubo, 1.0f);
return;
}
@@ -623,16 +621,14 @@
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1)});
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
- auto body = ast::StatementList{
- create<ast::ReturnStatement>(MemberAccessor("coord", "b"))};
+ auto body = ast::StatementList{Return(MemberAccessor("coord", "b"))};
- Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+ Func("sub_func", params, ty.f32(), body, {});
auto* var =
Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
@@ -640,10 +636,10 @@
Func("frag_main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -681,16 +677,14 @@
sem::AccessControl ac(ast::AccessControl::kReadOnly, s);
Global("coord", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(1)});
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(1)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
- auto body = ast::StatementList{
- create<ast::ReturnStatement>(MemberAccessor("coord", "b"))};
+ auto body = ast::StatementList{Return(MemberAccessor("coord", "b"))};
- Func("sub_func", params, ty.f32(), body, ast::DecorationList{});
+ Func("sub_func", params, ty.f32(), body, {});
auto* var =
Var("v", ty.f32(), ast::StorageClass::kFunction, Call("sub_func", 1.0f));
@@ -698,10 +692,10 @@
Func("frag_main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -731,10 +725,10 @@
TEST_F(MslGeneratorImplTest,
Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) {
Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr,
- ast::DecorationList{create<ast::LocationDecoration>(1)});
+ {create<ast::LocationDecoration>(1)});
auto* list = create<ast::BlockStatement>(ast::StatementList{
- create<ast::ReturnStatement>(),
+ Return(),
});
auto body = ast::StatementList{
@@ -742,12 +736,12 @@
create<ast::IfStatement>(create<ast::BinaryExpression>(
ast::BinaryOp::kEqual, Expr(1), Expr(1)),
list, ast::ElseStatementList{}),
- create<ast::ReturnStatement>(),
+ Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ {
+ Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
@@ -778,9 +772,9 @@
Func("my_func", params, ty.void_(),
ast::StatementList{
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{});
+ {});
GeneratorImpl& gen = Build();
@@ -821,8 +815,7 @@
sem::AccessControl ac(ast::AccessControl::kReadWrite, s);
Global("data", &ac, ast::StorageClass::kStorage, nullptr,
- ast::DecorationList{create<ast::BindingDecoration>(0),
- create<ast::GroupDecoration>(0)});
+ {create<ast::BindingDecoration>(0), create<ast::GroupDecoration>(0)});
{
auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction,
@@ -831,10 +824,10 @@
Func("a", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>(),
+ Return(),
},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+ {
+ Stage(ast::PipelineStage::kCompute),
});
}
@@ -843,10 +836,8 @@
MemberAccessor("data", "d"));
Func("b", ast::VariableList{}, ty.void_(),
- ast::StatementList{create<ast::VariableDeclStatement>(var),
- create<ast::ReturnStatement>()},
- ast::DecorationList{
- create<ast::StageDecoration>(ast::PipelineStage::kCompute)});
+ ast::StatementList{create<ast::VariableDeclStatement>(var), Return()},
+ {Stage(ast::PipelineStage::kCompute)});
}
GeneratorImpl& gen = Build();
diff --git a/test/compute_boids.wgsl b/test/compute_boids.wgsl
index 2f168dc..7570102 100644
--- a/test/compute_boids.wgsl
+++ b/test/compute_boids.wgsl
@@ -53,7 +53,7 @@
particles : array<Particle, 5>;
};
-[[binding(0), group(0)]] var<uniform> params : [[access(read)]] SimParams;
+[[binding(0), group(0)]] var<uniform> params : SimParams;
[[binding(1), group(0)]] var<storage> particlesA : [[access(read_write)]] Particles;
[[binding(2), group(0)]] var<storage> particlesB : [[access(read_write)]] Particles;
diff --git a/test/cube.wgsl b/test/cube.wgsl
index 313954a..6eefd59 100644
--- a/test/cube.wgsl
+++ b/test/cube.wgsl
@@ -17,7 +17,7 @@
modelViewProjectionMatrix : mat4x4<f32>;
};
-[[binding(0), group(0)]] var<uniform> uniforms : [[access(read)]] Uniforms;
+[[binding(0), group(0)]] var<uniform> uniforms : Uniforms;
struct VertexInput {
[[location(0)]] cur_position : vec4<f32>;