Move workgroup_size property into sem::Function
The workgroup size should not be a property of the function in the
AST. Moving it into sem::Function also lays the groundwork for
allowing both literals and module-scope constants to be used for the
workgroup_size attribute.
Bug: tint:713
Change-Id: I014be879e2adb81cfc5b0ea0e221035fae626223
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51261
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index a80472b..93266aa 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1287,6 +1287,20 @@
Mark(deco);
}
+ // Set work-group size defaults.
+ for (int i = 0; i < 3; i++) {
+ info->workgroup_size[i].value = 1;
+ info->workgroup_size[i].overridable_const = nullptr;
+ }
+
+ if (auto* workgroup =
+ ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
+ // TODO(crbug.com/tint/713): Handle non-literals.
+ info->workgroup_size[0].value = std::get<0>(workgroup->values());
+ info->workgroup_size[1].value = std::get<1>(workgroup->values());
+ info->workgroup_size[2].value = std::get<2>(workgroup->values());
+ }
+
if (!ValidateFunction(func, info)) {
return false;
}
@@ -2517,7 +2531,7 @@
info->declaration, const_cast<sem::Type*>(info->return_type),
remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars), info->return_statements,
- ancestor_entry_points[func->symbol()]);
+ ancestor_entry_points[func->symbol()], info->workgroup_size);
func_info_to_sem_func.emplace(info, sem_func);
sem.Add(func, sem_func);
}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 76a722f..ad428a8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -26,6 +26,7 @@
#include "src/scope_stack.h"
#include "src/sem/binding_point.h"
#include "src/sem/block_statement.h"
+#include "src/sem/function.h"
#include "src/sem/struct.h"
#include "src/utils/unique_vector.h"
@@ -112,6 +113,7 @@
std::vector<const ast::ReturnStatement*> return_statements;
sem::Type* return_type = nullptr;
std::string return_type_name;
+ std::array<sem::WorkgroupDimension, 3> workgroup_size;
// List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index e81be39..109b31a 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -32,6 +32,7 @@
#include "src/ast/switch_statement.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
+#include "src/ast/workgroup_decoration.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
@@ -887,6 +888,40 @@
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
}
+TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
+  // A function with no workgroup decoration must resolve to the default
+  // workgroup size of 1 in every dimension, with no overridable constant.
+  auto* fn = Func("main", ast::VariableList{}, ty.void_(), {}, {});
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem_fn = Sem().Get(fn);
+  ASSERT_NE(sem_fn, nullptr);
+
+  const auto& wgsize = sem_fn->workgroup_size();
+  for (int dim = 0; dim < 3; dim++) {
+    EXPECT_EQ(wgsize[dim].value, 1u) << "dimension " << dim;
+    EXPECT_EQ(wgsize[dim].overridable_const, nullptr) << "dimension " << dim;
+  }
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
+  // A compute entry point declared with literal workgroup sizes (8, 2, 3)
+  // must have those values recorded on its sem::Function; literal sizes
+  // carry no overridable constant.
+  auto* fn = Func("main", ast::VariableList{}, ty.void_(), {},
+                  {Stage(ast::PipelineStage::kCompute),
+                   create<ast::WorkgroupDecoration>(8, 2, 3)});
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem_fn = Sem().Get(fn);
+  ASSERT_NE(sem_fn, nullptr);
+
+  const unsigned expected[3] = {8u, 2u, 3u};
+  const auto& wgsize = sem_fn->workgroup_size();
+  for (int dim = 0; dim < 3; dim++) {
+    EXPECT_EQ(wgsize[dim].value, expected[dim]) << "dimension " << dim;
+    EXPECT_EQ(wgsize[dim].overridable_const, nullptr) << "dimension " << dim;
+  }
+}
+
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* st = Structure("S", {Member("first_member", ty.i32()),
Member("second_member", ty.f32())});