Import Tint changes from Dawn
Changes:
- 0e22bdbae79384907d7044572dc6c07831929c2e tint/msl: Fix emission of private variables by James Price <jrprice@google.com>
- 6af073cecc7cd68d39beaefadf63eb3bec54bb98 Tint&Dawn: Enable f16 override by Zhaoming Jiang <zhaoming.jiang@intel.com>
GitOrigin-RevId: 0e22bdbae79384907d7044572dc6c07831929c2e
Change-Id: I27827f2b07e40e764fcd42fd5d6c7c7b2e5493ac
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/124841
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/cmd/helper.cc b/src/tint/cmd/helper.cc
index 49b970a..6ff1e7a 100644
--- a/src/tint/cmd/helper.cc
+++ b/src/tint/cmd/helper.cc
@@ -505,6 +505,8 @@
return "bool";
case tint::inspector::Override::Type::kFloat32:
return "f32";
+ case tint::inspector::Override::Type::kFloat16:
+ return "f16";
case tint::inspector::Override::Type::kUint32:
return "u32";
case tint::inspector::Override::Type::kInt32:
diff --git a/src/tint/inspector/entry_point.h b/src/tint/inspector/entry_point.h
index fd17ba0..8fd003c 100644
--- a/src/tint/inspector/entry_point.h
+++ b/src/tint/inspector/entry_point.h
@@ -92,6 +92,7 @@
kFloat32,
kUint32,
kInt32,
+ kFloat16,
};
/// Type of the scalar
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index 93e5c58..0b8b22c 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -204,7 +204,11 @@
if (type->is_bool_scalar_or_vector()) {
override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) {
- override.type = Override::Type::kFloat32;
+ if (type->Is<type::F16>()) {
+ override.type = Override::Type::kFloat16;
+ } else {
+ override.type = Override::Type::kFloat32;
+ }
} else if (type->is_signed_integer_scalar()) {
override.type = Override::Type::kInt32;
} else if (type->is_unsigned_integer_scalar()) {
@@ -270,6 +274,10 @@
[&](const type::I32*) { return Scalar(value->ValueAs<i32>()); },
[&](const type::U32*) { return Scalar(value->ValueAs<u32>()); },
[&](const type::F32*) { return Scalar(value->ValueAs<f32>()); },
+ [&](const type::F16*) {
+ // Default value of f16 override is also stored as float scalar.
+ return Scalar(static_cast<float>(value->ValueAs<f16>()));
+ },
[&](const type::Bool*) { return Scalar(value->ValueAs<bool>()); });
continue;
}
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index b359074..de09369 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -908,18 +908,23 @@
}
TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
+ Enable(builtin::Extension::kF16);
+
Override("bool_var", ty.bool_());
Override("float_var", ty.f32());
Override("u32_var", ty.u32());
Override("i32_var", ty.i32());
+ Override("f16_var", ty.f16());
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("u32_func", "u32_var", ty.u32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("i32_func", "i32_var", ty.i32(), utils::Empty);
+ MakePlainGlobalReferenceBodyFunction("f16_func", "f16_var", ty.f16(), utils::Empty);
MakeCallerBodyFunction(
- "ep_func", utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func"},
+ "ep_func",
+ utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func", "f16_func"},
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1_i),
@@ -930,7 +935,7 @@
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
- ASSERT_EQ(4u, result[0].overrides.size());
+ ASSERT_EQ(5u, result[0].overrides.size());
EXPECT_EQ("bool_var", result[0].overrides[0].name);
EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
EXPECT_EQ("float_var", result[0].overrides[1].name);
@@ -939,6 +944,8 @@
EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
EXPECT_EQ("i32_var", result[0].overrides[3].name);
EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
+ EXPECT_EQ("f16_var", result[0].overrides[4].name);
+ EXPECT_EQ(inspector::Override::Type::kFloat16, result[0].overrides[4].type);
}
TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
@@ -1572,7 +1579,7 @@
EXPECT_EQ(100, result[OverrideId{6000}].AsI32());
}
-TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
+TEST_F(InspectorGetOverrideDefaultValuesTest, F32) {
Override("a", ty.f32(), Id(1_a));
Override("b", ty.f32(), Expr(0_f), Id(20_a));
Override("c", ty.f32(), Expr(-10_f), Id(300_a));
@@ -1609,6 +1616,46 @@
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
}
+TEST_F(InspectorGetOverrideDefaultValuesTest, F16) {
+ Enable(builtin::Extension::kF16);
+
+ Override("a", ty.f16(), Id(1_a));
+ Override("b", ty.f16(), Expr(0_h), Id(20_a));
+ Override("c", ty.f16(), Expr(-10_h), Id(300_a));
+ Override("d", Expr(15_h), Id(4000_a));
+ Override("3", Expr(42.0_h), Id(5000_a));
+ Override("e", ty.f16(), Mul(15_h, 10_a), Id(6000_a));
+
+ Inspector& inspector = Build();
+
+ auto result = inspector.GetOverrideDefaultValues();
+ ASSERT_EQ(6u, result.size());
+
+ ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
+ EXPECT_TRUE(result[OverrideId{1}].IsNull());
+
+ ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
+ // Default value of f16 override is also stored as float scalar.
+ EXPECT_TRUE(result[OverrideId{20}].IsFloat());
+ EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat());
+
+ ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
+ EXPECT_TRUE(result[OverrideId{300}].IsFloat());
+ EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat());
+
+ ASSERT_TRUE(result.find(OverrideId{4000}) != result.end());
+ EXPECT_TRUE(result[OverrideId{4000}].IsFloat());
+ EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat());
+
+ ASSERT_TRUE(result.find(OverrideId{5000}) != result.end());
+ EXPECT_TRUE(result[OverrideId{5000}].IsFloat());
+ EXPECT_FLOAT_EQ(42.0f, result[OverrideId{5000}].AsFloat());
+
+ ASSERT_TRUE(result.find(OverrideId{6000}) != result.end());
+ EXPECT_TRUE(result[OverrideId{6000}].IsFloat());
+ EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
+}
+
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
Override("v1", ty.f32(), Id(1_a));
Override("v20", ty.f32(), Id(20_a));
diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc
index fd0f649..c825969 100644
--- a/src/tint/resolver/override_test.cc
+++ b/src/tint/resolver/override_test.cc
@@ -66,10 +66,12 @@
}
TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
+ Enable(builtin::Extension::kF16);
+
auto* a = Override("a", ty.f32(), Expr(1_f));
- auto* b = Override("b", ty.f32(), Expr(1_f));
- auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u));
- auto* d = Override("d", ty.f32(), Expr(1_f), Id(4_u));
+ auto* b = Override("b", ty.f16(), Expr(1_h));
+ auto* c = Override("c", ty.i32(), Expr(1_i), Id(2_u));
+ auto* d = Override("d", ty.u32(), Expr(1_u), Id(4_u));
auto* e = Override("e", ty.f32(), Expr(1_f));
auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u));
@@ -102,16 +104,6 @@
EXPECT_EQ(r()->error(), "12:34 error: @id value must be between 0 and 65535");
}
-TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
- Enable(builtin::Extension::kF16);
-
- Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1_u));
-
- EXPECT_FALSE(r()->Resolve());
-
- EXPECT_EQ(r()->error(), "12:34 error: 'override' of type f16 is not implemented yet");
-}
-
TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) {
auto* a = Override("a", ty.f32());
auto* b = Override("b", ty.f32(), Expr(1_f));
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index c3b8a3d..b005779 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -788,11 +788,6 @@
return false;
}
- if (storage_ty->Is<type::F16>()) {
- AddError("'override' of type f16 is not implemented yet", decl->source);
- return false;
- }
-
return true;
}
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc
index 044133c..e82b1bd 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc
@@ -33,7 +33,7 @@
namespace tint::transform {
namespace {
-using WorkgroupParameterMemberList = utils::Vector<const ast::StructMember*, 8>;
+using StructMemberList = utils::Vector<const ast::StructMember*, 8>;
// The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr";
@@ -114,7 +114,7 @@
const sem::Variable* var,
Symbol new_var_symbol,
std::function<Symbol()> workgroup_param,
- WorkgroupParameterMemberList& workgroup_parameter_members,
+ StructMemberList& workgroup_parameter_members,
bool& is_pointer,
bool& is_wrapped) {
auto* ty = var->Type()->UnwrapRef();
@@ -188,21 +188,14 @@
member_ptr);
ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
is_pointer = true;
-
- break;
+ } else {
+ auto* disable_validation =
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace);
+ auto* initializer = ctx.Clone(var->Declaration()->initializer);
+ auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
+ utils::Vector{disable_validation});
+ ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
}
- [[fallthrough]];
- }
- case builtin::AddressSpace::kPrivate: {
- // Variables in the Private and Workgroup address spaces are redeclared at function
- // scope. Disable address space validation on this variable.
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace);
- auto* initializer = ctx.Clone(var->Declaration()->initializer);
- auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
- utils::Vector{disable_validation});
- ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
-
break;
}
case builtin::AddressSpace::kPushConstant: {
@@ -234,6 +227,8 @@
auto sc = var->AddressSpace();
switch (sc) {
case builtin::AddressSpace::kPrivate:
+ // Private variables are passed all together in a struct.
+ return;
case builtin::AddressSpace::kStorage:
case builtin::AddressSpace::kUniform:
case builtin::AddressSpace::kHandle:
@@ -275,12 +270,12 @@
/// @param var the variable to replace
/// @param new_var the symbol to use for replacement
/// @param is_pointer true if `new_var` is a pointer to the new variable
- /// @param is_wrapped true if `new_var` is an array wrapped in a structure
+ /// @param member_name if valid, the name of the struct member that holds this variable
void ReplaceUsesInFunction(const ast::Function* func,
const sem::Variable* var,
Symbol new_var,
bool is_pointer,
- bool is_wrapped) {
+ Symbol member_name) {
for (auto* user : var->Users()) {
if (user->Stmt()->Function()->Declaration() == func) {
const ast::Expression* expr = ctx.dst->Expr(new_var);
@@ -288,16 +283,16 @@
// If this identifier is used by an address-of operator, just remove the
// address-of instead of adding a deref, since we already have a pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
- if (ident_to_address_of_.count(ident)) {
+ if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
ctx.Replace(ident_to_address_of_[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
- if (is_wrapped) {
- // Get the member from the wrapper structure.
- expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
+ if (member_name.IsValid()) {
+ // Get the member from the containing structure.
+ expr = ctx.dst->MemberAccessor(expr, member_name);
}
ctx.Replace(user->Declaration(), expr);
}
@@ -312,8 +307,34 @@
utils::Vector<const ast::Function*, 8> functions_to_process;
+ // Collect private variables into a single structure.
+ StructMemberList private_struct_members;
+ utils::Vector<std::function<const ast::AssignmentStatement*()>, 4> private_initializers;
+ std::unordered_set<const ast::Function*> uses_privates;
+
// Build a list of functions that transitively reference any module-scope variables.
for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
+ if (auto* var = decl->As<ast::Var>()) {
+ auto* sem_var = ctx.src->Sem().Get(var);
+ if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ // Create a member in the private variable struct.
+ auto* ty = sem_var->Type()->UnwrapRef();
+ auto name = ctx.Clone(var->name->symbol);
+ private_struct_members.Push(ctx.dst->Member(name, CreateASTTypeFor(ctx, ty)));
+ CloneStructTypes(ty);
+
+ // Create a statement to assign the initializer if present.
+ if (var->initializer) {
+ private_initializers.Push([&, name, var]() {
+ return ctx.dst->Assign(
+ ctx.dst->MemberAccessor(PrivateStructVariableName(), name),
+ ctx.Clone(var->initializer));
+ });
+ }
+ }
+ continue;
+ }
+
auto* func_ast = decl->As<ast::Function>();
if (!func_ast) {
continue;
@@ -324,8 +345,10 @@
bool needs_processing = false;
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
if (var->AddressSpace() != builtin::AddressSpace::kUndefined) {
+ if (var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ uses_privates.insert(func_ast);
+ }
needs_processing = true;
- break;
}
}
if (needs_processing) {
@@ -339,6 +362,14 @@
}
}
+ if (!private_struct_members.IsEmpty()) {
+ // Create the private variable struct.
+ ctx.dst->Structure(PrivateStructName(), std::move(private_struct_members));
+ // Passing a pointer to a private variable will now involve passing a pointer to the
+ // member of a structure, so enable the extension that allows this.
+ ctx.dst->Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
+ }
+
// Build a list of `&ident` expressions. We'll use this later to avoid generating
// expressions of the form `&*ident`, which break WGSL validation rules when this expression
// is passed to a function.
@@ -370,7 +401,7 @@
// We aggregate all workgroup variables into a struct to avoid hitting MSL's limit for
// threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
- WorkgroupParameterMemberList workgroup_parameter_members;
+ StructMemberList workgroup_parameter_members;
auto workgroup_param = [&]() {
if (!workgroup_parameter_symbol.IsValid()) {
workgroup_parameter_symbol = ctx.dst->Sym();
@@ -378,12 +409,43 @@
return workgroup_parameter_symbol;
};
+ // If this function references any private variables, it needs to take the private
+ // variable struct as a parameter (or declare it, if it is an entry point function).
+ if (uses_privates.count(func_ast)) {
+ if (is_entry_point) {
+ // Create a local declaration for the private variable struct.
+ auto* var = ctx.dst->Var(
+ PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
+ builtin::AddressSpace::kPrivate,
+ utils::Vector{
+ ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace),
+ });
+ ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
+
+ // Initialize the members of that struct with the original initializers.
+ for (auto init : private_initializers) {
+ ctx.InsertFront(func_ast->body->statements, init());
+ }
+ } else {
+ // Create a parameter that is a pointer to the private variable struct.
+ auto ptr = ctx.dst->ty.pointer(ctx.dst->ty(PrivateStructName()),
+ builtin::AddressSpace::kPrivate);
+ auto* param = ctx.dst->Param(PrivateStructVariableName(), ptr);
+ ctx.InsertBack(func_ast->params, param);
+ }
+ }
+
// Process and redeclare all variables referenced by the function.
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
if (var->AddressSpace() == builtin::AddressSpace::kUndefined) {
continue;
}
- if (local_private_vars_.count(var)) {
+ if (var->AddressSpace() == builtin::AddressSpace::kPrivate) {
+ // Private variable are collected into a single struct that is passed by
+ // pointer (handled above), so we just need to replace the uses here.
+ ReplaceUsesInFunction(func_ast, var, PrivateStructVariableName(),
+ /* is_pointer */ !is_entry_point,
+ ctx.Clone(var->Declaration()->name->symbol));
continue;
}
@@ -396,49 +458,25 @@
// Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false;
- // Check if this is a private variable that is only referenced by this function.
- bool local_private = false;
- if (var->AddressSpace() == builtin::AddressSpace::kPrivate) {
- local_private = true;
- for (auto* user : var->Users()) {
- auto* stmt = user->Stmt();
- if (!stmt || stmt->Function() != func_sem) {
- local_private = false;
- break;
- }
- }
- }
-
- if (local_private) {
- // Redeclare the variable at function scope.
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace);
- auto* initializer = ctx.Clone(var->Declaration()->initializer);
- auto* local_var = ctx.dst->Var(new_var_symbol,
- CreateASTTypeFor(ctx, var->Type()->UnwrapRef()),
- builtin::AddressSpace::kPrivate, initializer,
- utils::Vector{disable_validation});
- ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var));
- local_private_vars_.insert(var);
+ // Process the variable to redeclare it as a parameter or local variable.
+ if (is_entry_point) {
+ ProcessVariableInEntryPoint(func_ast, var, new_var_symbol, workgroup_param,
+ workgroup_parameter_members, is_pointer,
+ is_wrapped);
} else {
- // Process the variable to redeclare it as a parameter or local variable.
- if (is_entry_point) {
- ProcessVariableInEntryPoint(func_ast, var, new_var_symbol, workgroup_param,
- workgroup_parameter_members, is_pointer,
- is_wrapped);
- } else {
- ProcessVariableInUserFunction(func_ast, var, new_var_symbol, is_pointer);
- if (var->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
- needs_pointer_aliasing = true;
- }
+ ProcessVariableInUserFunction(func_ast, var, new_var_symbol, is_pointer);
+ if (var->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
+ needs_pointer_aliasing = true;
}
-
- // Record the replacement symbol.
- var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
}
+ // Record the replacement symbol.
+ var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
+
// Replace all uses of the module-scope variable.
- ReplaceUsesInFunction(func_ast, var, new_var_symbol, is_pointer, is_wrapped);
+ ReplaceUsesInFunction(
+ func_ast, var, new_var_symbol, is_pointer,
+ is_wrapped ? ctx.dst->Sym(kWrappedArrayMemberName) : Symbol());
}
// Allow pointer aliasing if needed.
@@ -468,6 +506,15 @@
auto* call_sem = ctx.src->Sem().Get(call)->Unwrap()->As<sem::Call>();
auto* target_sem = call_sem->Target()->As<sem::Function>();
+ // Pass the private variable struct pointer if needed.
+ if (uses_privates.count(target_sem->Declaration())) {
+ const ast::Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
+ if (is_entry_point) {
+ arg = ctx.dst->AddressOf(arg);
+ }
+ ctx.InsertBack(call->args, arg);
+ }
+
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
@@ -509,16 +556,35 @@
}
}
+ /// @returns the name of the structure that contains all of the module-scope private variables
+ Symbol PrivateStructName() {
+ if (!private_struct_name.IsValid()) {
+ private_struct_name = ctx.dst->Symbols().New("tint_private_vars_struct");
+ }
+ return private_struct_name;
+ }
+
+ /// @returns the name of the variable that contains all of the module-scope private variables
+ Symbol PrivateStructVariableName() {
+ if (!private_struct_variable_name.IsValid()) {
+ private_struct_variable_name = ctx.dst->Symbols().New("tint_private_vars");
+ }
+ return private_struct_variable_name;
+ }
+
private:
// The structures that have already been cloned by this transform.
std::unordered_set<const sem::Struct*> cloned_structs_;
- // Set of a private variables that are local to a single function.
- std::unordered_set<const sem::Variable*> local_private_vars_;
-
// Map from identifier expression to the address-of expression that uses it.
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
ident_to_address_of_;
+
+ // The name of the structure that contains all the module-scope private variables.
+ Symbol private_struct_name;
+
+ // The name of the structure variable that contains all the module-scope private variables.
+ Symbol private_struct_variable_name;
};
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
index a9aaa28..5991a1a 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param_test.cc
@@ -49,11 +49,17 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
@internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_1 : f32;
- tint_symbol = tint_symbol_1;
+ tint_symbol = tint_private_vars.p;
}
)";
@@ -74,11 +80,17 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
@internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_1 : f32;
- tint_symbol = tint_symbol_1;
+ tint_symbol = tint_private_vars.p;
}
)";
@@ -118,32 +130,38 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
fn no_uses() {
}
-fn zoo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<private, f32>) {
- *(tint_symbol) = (*(tint_symbol) * 2.0);
+fn zoo(tint_private_vars : ptr<private, tint_private_vars_struct>) {
+ (*(tint_private_vars)).p = ((*(tint_private_vars)).p * 2.0);
}
@internal(disable_validation__ignore_pointer_aliasing)
-fn bar(a : f32, b : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<private, f32>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_2 : ptr<workgroup, f32>) {
- *(tint_symbol_1) = a;
- *(tint_symbol_2) = b;
- zoo(tint_symbol_1);
+fn bar(a : f32, b : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ (*(tint_private_vars)).p = a;
+ *(tint_symbol) = b;
+ zoo(tint_private_vars);
}
@internal(disable_validation__ignore_pointer_aliasing)
-fn foo(a : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_3 : ptr<private, f32>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_4 : ptr<workgroup, f32>) {
+fn foo(a : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
- bar(a, b, tint_symbol_3, tint_symbol_4);
+ bar(a, b, tint_private_vars, tint_symbol_1);
no_uses();
}
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_5 : f32;
- @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_6 : f32;
- foo(1.0, &(tint_symbol_5), &(tint_symbol_6));
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_2 : f32;
+ foo(1.0, &(tint_private_vars), &(tint_symbol_2));
}
)";
@@ -183,17 +201,23 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_5 : f32;
- @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_6 : f32;
- foo(1.0, &(tint_symbol_5), &(tint_symbol_6));
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_2 : f32;
+ foo(1.0, &(tint_private_vars), &(tint_symbol_2));
}
@internal(disable_validation__ignore_pointer_aliasing)
-fn foo(a : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_3 : ptr<private, f32>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_4 : ptr<workgroup, f32>) {
+fn foo(a : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
- bar(a, b, tint_symbol_3, tint_symbol_4);
+ bar(a, b, tint_private_vars, tint_symbol_1);
no_uses();
}
@@ -201,14 +225,14 @@
}
@internal(disable_validation__ignore_pointer_aliasing)
-fn bar(a : f32, b : f32, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_1 : ptr<private, f32>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol_2 : ptr<workgroup, f32>) {
- *(tint_symbol_1) = a;
- *(tint_symbol_2) = b;
- zoo(tint_symbol_1);
+fn bar(a : f32, b : f32, tint_private_vars : ptr<private, tint_private_vars_struct>, @internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<workgroup, f32>) {
+ (*(tint_private_vars)).p = a;
+ *(tint_symbol) = b;
+ zoo(tint_private_vars);
}
-fn zoo(@internal(disable_validation__ignore_address_space) @internal(disable_validation__ignore_invalid_pointer_argument) tint_symbol : ptr<private, f32>) {
- *(tint_symbol) = (*(tint_symbol) * 2.0);
+fn zoo(tint_private_vars : ptr<private, tint_private_vars_struct>) {
+ (*(tint_private_vars)).p = ((*(tint_private_vars)).p * 2.0);
}
)";
@@ -229,11 +253,19 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32 = 1.0;
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_1 : f32 = f32();
- let x : f32 = (tint_symbol + tint_symbol_1);
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.a = 1.0;
+ tint_private_vars.b = f32();
+ let x : f32 = (tint_private_vars.a + tint_private_vars.b);
}
)";
@@ -254,11 +286,19 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32 = 1.0;
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol_1 : f32 = f32();
- let x : f32 = (tint_symbol + tint_symbol_1);
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.a = 1.0;
+ tint_private_vars.b = f32();
+ let x : f32 = (tint_private_vars.a + tint_private_vars.b);
}
)";
@@ -282,12 +322,18 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32;
- @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_1 : f32;
- let p_ptr : ptr<private, f32> = &(tint_symbol);
- let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ let p_ptr : ptr<private, f32> = &(tint_private_vars.p);
+ let w_ptr : ptr<workgroup, f32> = &(tint_symbol);
let x : f32 = (*(p_ptr) + *(w_ptr));
*(p_ptr) = x;
}
@@ -313,12 +359,18 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+}
+
@compute @workgroup_size(1)
fn main() {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32;
- @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol_1 : f32;
- let p_ptr : ptr<private, f32> = &(tint_symbol);
- let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ @internal(disable_validation__ignore_address_space) var<workgroup> tint_symbol : f32;
+ let p_ptr : ptr<private, f32> = &(tint_private_vars.p);
+ let w_ptr : ptr<workgroup, f32> = &(tint_symbol);
let x : f32 = (*(p_ptr) + *(w_ptr));
*(p_ptr) = x;
}
@@ -1151,6 +1203,7 @@
var<private> p : f32;
var<workgroup> w : f32;
+var<private> p_with_init : f32 = 42;
@group(0) @binding(0)
var<uniform> ub : S;
@@ -1166,6 +1219,13 @@
)";
auto* expect = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+struct tint_private_vars_struct {
+ p : f32,
+ p_with_init : f32,
+}
+
struct S {
a : f32,
}
@@ -1180,16 +1240,22 @@
EXPECT_EQ(expect, str(got));
}
-// Test that a private variable that is only referenced by a single user-defined function is
-// promoted to a function scope variable, rather than passed as a parameter.
-TEST_F(ModuleScopeVarToEntryPointParamTest, PromotePrivateToFunctionScope) {
+TEST_F(ModuleScopeVarToEntryPointParamTest, MultiplePrivateVariables) {
auto* src = R"(
-var<private> p : f32;
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
-fn foo(a : f32) -> f32 {
- let x = p;
- p = x * a;
- return p;
+var<private> a : f32;
+var<private> b : f32 = 42;
+var<private> c : S = S(1, 2, 3);
+var<private> d : S;
+var<private> unused : f32;
+
+fn foo(x : f32) -> f32 {
+ return (a + b + c.a + d.c) * x;
}
@compute @workgroup_size(1)
@@ -1199,16 +1265,32 @@
)";
auto* expect = R"(
-fn foo(a : f32) -> f32 {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32;
- let x = tint_symbol;
- tint_symbol = (x * a);
- return tint_symbol;
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+struct tint_private_vars_struct {
+ a : f32,
+ b : f32,
+ c : S,
+ d : S,
+ unused : f32,
+}
+
+fn foo(x : f32, tint_private_vars : ptr<private, tint_private_vars_struct>) -> f32 {
+ return (((((*(tint_private_vars)).a + (*(tint_private_vars)).b) + (*(tint_private_vars)).c.a) + (*(tint_private_vars)).d.c) * x);
}
@compute @workgroup_size(1)
fn main() {
- _ = foo(1.0);
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.b = 42;
+ tint_private_vars.c = S(1, 2, 3);
+ _ = foo(1.0, &(tint_private_vars));
}
)";
@@ -1217,36 +1299,59 @@
EXPECT_EQ(expect, str(got));
}
-// Test that a private variable that is only referenced by a single user-defined function is
-// promoted to a function scope variable, rather than passed as a parameter.
-TEST_F(ModuleScopeVarToEntryPointParamTest, PromotePrivateToFunctionScope_OutOfOrder) {
+TEST_F(ModuleScopeVarToEntryPointParamTest, MultiplePrivateVariables_OutOfOrder) {
auto* src = R"(
-var<private> p : f32;
+var<private> a : f32;
+var<private> c : S = S(1, 2, 3);
+var<private> unused : f32;
@compute @workgroup_size(1)
fn main() {
_ = foo(1.0);
}
-fn foo(a : f32) -> f32 {
- let x = p;
- p = x * a;
- return p;
+fn foo(x : f32) -> f32 {
+ return (a + b + c.a + d.c) * x;
}
+var<private> b : f32 = 42;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
+}
+
+var<private> d : S;
)";
auto* expect = R"(
-@compute @workgroup_size(1)
-fn main() {
- _ = foo(1.0);
+enable chromium_experimental_full_ptr_parameters;
+
+struct S {
+ a : f32,
+ b : f32,
+ c : f32,
}
-fn foo(a : f32) -> f32 {
- @internal(disable_validation__ignore_address_space) var<private> tint_symbol : f32;
- let x = tint_symbol;
- tint_symbol = (x * a);
- return tint_symbol;
+struct tint_private_vars_struct {
+ a : f32,
+ c : S,
+ unused : f32,
+ b : f32,
+ d : S,
+}
+
+@compute @workgroup_size(1)
+fn main() {
+ @internal(disable_validation__ignore_address_space) var<private> tint_private_vars : tint_private_vars_struct;
+ tint_private_vars.c = S(1, 2, 3);
+ tint_private_vars.b = 42;
+ _ = foo(1.0, &(tint_private_vars));
+}
+
+fn foo(x : f32, tint_private_vars : ptr<private, tint_private_vars_struct>) -> f32 {
+ return (((((*(tint_private_vars)).a + (*(tint_private_vars)).b) + (*(tint_private_vars)).c.a) + (*(tint_private_vars)).d.c) * x);
}
)";
diff --git a/src/tint/transform/substitute_override_test.cc b/src/tint/transform/substitute_override_test.cc
index 1b4646f..5deab72 100644
--- a/src/tint/transform/substitute_override_test.cc
+++ b/src/tint/transform/substitute_override_test.cc
@@ -84,15 +84,16 @@
TEST_F(SubstituteOverrideTest, ImplicitId) {
auto* src = R"(
+enable f16;
+
override i_width: i32;
override i_height = 1i;
override f_width: f32;
override f_height = 1.f;
-// TODO(crbug.com/tint/1473)
-// override h_width: f16;
-// override h_height = 1.h;
+override h_width: f16;
+override h_height = 1.h;
override b_width: bool;
override b_height = true;
@@ -106,6 +107,8 @@
)";
auto* expect = R"(
+enable f16;
+
const i_width : i32 = 42i;
const i_height = 11i;
@@ -114,6 +117,10 @@
const f_height = 12.3999996185302734375f;
+const h_width : f16 = 9.3984375h;
+
+const h_height = 3.3984375h;
+
const b_width : bool = true;
const b_height = false;
@@ -131,10 +138,10 @@
cfg.map.insert({OverrideId{1}, 11.0});
cfg.map.insert({OverrideId{2}, 22.3});
cfg.map.insert({OverrideId{3}, 12.4});
- // cfg.map.insert({OverrideId{4}, 9.4});
- // cfg.map.insert({OverrideId{5}, 3.4});
- cfg.map.insert({OverrideId{4}, 1.0});
- cfg.map.insert({OverrideId{5}, 0.0});
+ cfg.map.insert({OverrideId{4}, 9.4});
+ cfg.map.insert({OverrideId{5}, 3.4});
+ cfg.map.insert({OverrideId{6}, 1.0});
+ cfg.map.insert({OverrideId{7}, 0.0});
DataMap data;
data.Add<SubstituteOverride::Config>(cfg);
@@ -153,9 +160,8 @@
@id(1) override f_width: f32;
@id(9) override f_height = 1.f;
-// TODO(crbug.com/tint/1473)
-// @id(2) override h_width: f16;
-// @id(8) override h_height = 1.h;
+@id(2) override h_width: f16;
+@id(8) override h_height = 1.h;
@id(3) override b_width: bool;
@id(7) override b_height = true;
@@ -179,6 +185,10 @@
const f_height = 12.3999996185302734375f;
+const h_width : f16 = 9.3984375h;
+
+const h_height = 3.3984375h;
+
const b_width : bool = true;
const b_height = false;
diff --git a/src/tint/writer/msl/generator_impl_builtin_test.cc b/src/tint/writer/msl/generator_impl_builtin_test.cc
index 6c70abe..9868d5e 100644
--- a/src/tint/writer/msl/generator_impl_builtin_test.cc
+++ b/src/tint/writer/msl/generator_impl_builtin_test.cc
@@ -1089,9 +1089,13 @@
T tint_dot3(vec<T,3> a, vec<T,3> b) {
return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
}
+struct tint_private_vars_struct {
+ int3 v;
+};
+
kernel void test_function() {
- thread int3 tint_symbol = 0;
- int r = tint_dot3(tint_symbol, tint_symbol);
+ thread tint_private_vars_struct tint_private_vars = {};
+ int r = tint_dot3(tint_private_vars.v, tint_private_vars.v);
return;
}
diff --git a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc
index fa0c53f..f4b0578 100644
--- a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc
+++ b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc
@@ -527,7 +527,10 @@
gen.increment_indent();
ASSERT_TRUE(gen.Generate()) << gen.error();
- EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = 0.0f;\n"));
+ EXPECT_THAT(gen.result(), HasSubstr(R"(thread tint_private_vars_struct tint_private_vars = {};
+ float const tint_symbol = tint_private_vars.a;
+ return;
+)"));
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {