Report referenced pipeline overridable constants

Add the list of pipeline overridable constants referenced by each entry point to the entry point data reported by the inspector.

BUG=tint:855
Change-Id: I043e48afed1503a4267dc4cb198fb86245984551
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53820
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/inspector/entry_point.h b/src/inspector/entry_point.h index aac9a4c..b3570bb 100644 --- a/src/inspector/entry_point.h +++ b/src/inspector/entry_point.h
@@ -45,6 +45,13 @@ ComponentType component_type; }; +/// Reflection data about a pipeline overridable constant referenced by an entry +/// point +struct OverridableConstant { + /// Name of the constant + std::string name; +}; + /// Reflection data for an entry point in the shader. struct EntryPoint { /// Constructors @@ -71,6 +78,8 @@ std::vector<StageVariable> input_variables; /// List of the output variable accessed via this entry point. std::vector<StageVariable> output_variables; + /// List of the pipeline overridable constants accessed via this entry point. + std::vector<OverridableConstant> overridable_constants; /// @returns the size of the workgroup in {x,y,z} format std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index c0a100c..8fe3b32 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc
@@ -222,7 +222,6 @@ entry_point.output_variables); } - // TODO(crbug.com/tint/697): Remove this. for (auto* var : sem->ReferencedModuleVariables()) { auto* decl = var->Declaration(); @@ -231,32 +230,43 @@ continue; } - StageVariable stage_variable; - stage_variable.name = name; + // TODO(crbug.com/tint/697): Remove this. + { + StageVariable stage_variable; + stage_variable.name = name; - stage_variable.component_type = ComponentType::kUnknown; - auto* type = var->Type()->UnwrapRef(); - if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { - stage_variable.component_type = ComponentType::kFloat; - } else if (type->is_unsigned_scalar_or_vector()) { - stage_variable.component_type = ComponentType::kUInt; - } else if (type->is_signed_scalar_or_vector()) { - stage_variable.component_type = ComponentType::kSInt; + stage_variable.component_type = ComponentType::kUnknown; + auto* type = var->Type()->UnwrapRef(); + if (type->is_float_scalar_or_vector() || type->is_float_matrix()) { + stage_variable.component_type = ComponentType::kFloat; + } else if (type->is_unsigned_scalar_or_vector()) { + stage_variable.component_type = ComponentType::kUInt; + } else if (type->is_signed_scalar_or_vector()) { + stage_variable.component_type = ComponentType::kSInt; + } + + auto* location_decoration = + ast::GetDecoration<ast::LocationDecoration>(decl->decorations()); + if (location_decoration) { + stage_variable.has_location_decoration = true; + stage_variable.location_decoration = location_decoration->value(); + } else { + stage_variable.has_location_decoration = false; + } + + if (var->StorageClass() == ast::StorageClass::kInput) { + entry_point.input_variables.push_back(stage_variable); + } else if (var->StorageClass() == ast::StorageClass::kOutput) { + entry_point.output_variables.push_back(stage_variable); + } } - auto* location_decoration = - ast::GetDecoration<ast::LocationDecoration>(decl->decorations()); - if (location_decoration) { - 
stage_variable.has_location_decoration = true; - stage_variable.location_decoration = location_decoration->value(); - } else { - stage_variable.has_location_decoration = false; - } - - if (var->StorageClass() == ast::StorageClass::kInput) { - entry_point.input_variables.push_back(stage_variable); - } else if (var->StorageClass() == ast::StorageClass::kOutput) { - entry_point.output_variables.push_back(stage_variable); + { + if (var->IsPipelineConstant()) { + OverridableConstant overridable_constant; + overridable_constant.name = name; + entry_point.overridable_constants.push_back(overridable_constant); + } } }
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 9949452..0e050f6 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc
@@ -148,10 +148,10 @@ /// will be added. /// @returns the constant that was created template <class T> - ast::Variable* AddConstantWithID(std::string name, - uint32_t id, - ast::Type* type, - T* val) { + ast::Variable* AddOverridableConstantWithID(std::string name, + uint32_t id, + ast::Type* type, + T* val) { ast::Expression* constructor = nullptr; if (val) { constructor = Expr(*val); @@ -169,9 +169,9 @@ /// will be added. /// @returns the constant that was created template <class T> - ast::Variable* AddConstantWithoutID(std::string name, - ast::Type* type, - T* val) { + ast::Variable* AddOverridableConstantWithoutID(std::string name, + ast::Type* type, + T* val) { ast::Expression* constructor = nullptr; if (val) { constructor = Expr(*val); @@ -182,6 +182,25 @@ }); } + /// Generates a function that references a module-scope constant + /// @param func name of the function created + /// @param var name of the constant to be referenced + /// @param type type of the constant being referenced + /// @param decorations the function decorations + /// @returns a function object + ast::Function* MakeConstReferenceBodyFunction( + std::string func, + std::string var, + ast::Type* type, + ast::DecorationList decorations) { + ast::StatementList stmts; + stmts.emplace_back(Decl(Var("local_" + var, type))); + stmts.emplace_back(Assign("local_" + var, var)); + stmts.emplace_back(Return()); + + return Func(func, ast::VariableList(), ty.void_(), stmts, decorations); + } + /// @param vec Vector of StageVariable to be searched /// @param name Name to be searching for /// @returns true if name is in vec, otherwise false @@ -1446,6 +1465,78 @@ EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[0].component_type); } +TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) { + AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr); + MakeEmptyBodyFunction("ep_func", {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result =
inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_EQ(0u, result[0].overridable_constants.size()); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { + AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr); + MakeConstReferenceBodyFunction("ep_func", "foo", ty.f32(), + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { + AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr); + MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); + MakeCallerBodyFunction("ep_func", {"callee_func"}, + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { + AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr); + AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr); + MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); + MakeCallerBodyFunction("ep_func", {"callee_func"}, + {Stage(ast::PipelineStage::kCompute)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overridable_constants.size()); + EXPECT_EQ("foo", result[0].overridable_constants[0].name); +} + +TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { + ast::Struct* foo_struct_type = MakeUniformBufferType("foo_type",
{ty.i32()}); + AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0); + MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}}); + MakeCallerBodyFunction("ep_func", {"ub_func"}, + {Stage(ast::PipelineStage::kFragment)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_EQ(0u, result[0].overridable_constants.size()); +} + // TODO(rharrison): Reenable once GetRemappedNameForEntryPoint isn't a pass // through TEST_F(InspectorGetRemappedNameForEntryPointTest, DISABLED_NoFunctions) { @@ -1518,9 +1609,9 @@ TEST_F(InspectorGetConstantIDsTest, Bool) { bool val_true = true; bool val_false = false; - AddConstantWithID<bool>("foo", 1, ty.bool_(), nullptr); - AddConstantWithID<bool>("bar", 20, ty.bool_(), &val_true); - AddConstantWithID<bool>("baz", 300, ty.bool_(), &val_false); + AddOverridableConstantWithID<bool>("foo", 1, ty.bool_(), nullptr); + AddOverridableConstantWithID<bool>("bar", 20, ty.bool_(), &val_true); + AddOverridableConstantWithID<bool>("baz", 300, ty.bool_(), &val_false); Inspector& inspector = Build(); @@ -1541,8 +1632,8 @@ TEST_F(InspectorGetConstantIDsTest, U32) { uint32_t val = 42; - AddConstantWithID<uint32_t>("foo", 1, ty.u32(), nullptr); - AddConstantWithID<uint32_t>("bar", 20, ty.u32(), &val); + AddOverridableConstantWithID<uint32_t>("foo", 1, ty.u32(), nullptr); + AddOverridableConstantWithID<uint32_t>("bar", 20, ty.u32(), &val); Inspector& inspector = Build(); @@ -1560,9 +1651,9 @@ TEST_F(InspectorGetConstantIDsTest, I32) { int32_t val_neg = -42; int32_t val_pos = 42; - AddConstantWithID<int32_t>("foo", 1, ty.i32(), nullptr); - AddConstantWithID<int32_t>("bar", 20, ty.i32(), &val_neg); - AddConstantWithID<int32_t>("baz", 300, ty.i32(), &val_pos); + AddOverridableConstantWithID<int32_t>("foo", 1, ty.i32(), nullptr); + AddOverridableConstantWithID<int32_t>("bar", 20, ty.i32(), &val_neg); + AddOverridableConstantWithID<int32_t>("baz", 300, ty.i32(),
&val_pos); Inspector& inspector = Build(); @@ -1585,10 +1676,10 @@ float val_zero = 0.0f; float val_neg = -10.0f; float val_pos = 15.0f; - AddConstantWithID<float>("foo", 1, ty.f32(), nullptr); - AddConstantWithID<float>("bar", 20, ty.f32(), &val_zero); - AddConstantWithID<float>("baz", 300, ty.f32(), &val_neg); - AddConstantWithID<float>("x", 4000, ty.f32(), &val_pos); + AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr); + AddOverridableConstantWithID<float>("bar", 20, ty.f32(), &val_zero); + AddOverridableConstantWithID<float>("baz", 300, ty.f32(), &val_neg); + AddOverridableConstantWithID<float>("x", 4000, ty.f32(), &val_pos); Inspector& inspector = Build(); @@ -1612,12 +1703,12 @@ } TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { - AddConstantWithID<float>("v1", 1, ty.f32(), nullptr); - AddConstantWithID<float>("v20", 20, ty.f32(), nullptr); - AddConstantWithID<float>("v300", 300, ty.f32(), nullptr); - auto* a = AddConstantWithoutID<float>("a", ty.f32(), nullptr); - auto* b = AddConstantWithoutID<float>("b", ty.f32(), nullptr); - auto* c = AddConstantWithoutID<float>("c", ty.f32(), nullptr); + AddOverridableConstantWithID<float>("v1", 1, ty.f32(), nullptr); + AddOverridableConstantWithID<float>("v20", 20, ty.f32(), nullptr); + AddOverridableConstantWithID<float>("v300", 300, ty.f32(), nullptr); + auto* a = AddOverridableConstantWithoutID<float>("a", ty.f32(), nullptr); + auto* b = AddOverridableConstantWithoutID<float>("b", ty.f32(), nullptr); + auto* c = AddOverridableConstantWithoutID<float>("c", ty.f32(), nullptr); Inspector& inspector = Build();
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 41c4114..e80d11d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -135,8 +135,8 @@ if (current_function_ == nullptr) { return; } - if (var->storage_class == ast::StorageClass::kNone || - var->storage_class == ast::StorageClass::kFunction) { + + if (var->kind != VariableKind::kGlobal) { return; } @@ -496,7 +496,7 @@ } auto* info = variable_infos_.Create(var, const_cast<sem::Type*>(type), - type_name, storage_class, access); + type_name, storage_class, access, kind); variable_to_info_.emplace(var, info); return info; @@ -3377,12 +3377,14 @@ sem::Type* ty, const std::string& tn, ast::StorageClass sc, - ast::Access ac) + ast::Access ac, + VariableKind k) : declaration(decl), type(ty), type_name(tn), storage_class(sc), - access(ac) {} + access(ac), + kind(k) {} Resolver::VariableInfo::~VariableInfo() = default;
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 9f91027..b58d2b5 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h
@@ -86,6 +86,9 @@ bool IsHostShareable(const sem::Type* type); private: + /// Describes the context in which a variable is declared + enum class VariableKind { kParameter, kLocal, kGlobal }; + /// Structure holding semantic information about a variable. /// Used to build the sem::Variable nodes at the end of resolving. struct VariableInfo { @@ -93,7 +96,8 @@ sem::Type* type, const std::string& type_name, ast::StorageClass storage_class, - ast::Access ac); + ast::Access ac, + VariableKind k); ~VariableInfo(); ast::Variable const* const declaration; @@ -103,6 +107,7 @@ ast::Access const access; std::vector<ast::IdentifierExpression*> users; sem::BindingPoint binding_point; + VariableKind const kind; }; struct IntrinsicCallInfo { @@ -190,9 +195,6 @@ sem::Type* const sem; }; - /// Describes the context in which a variable is declared - enum class VariableKind { kParameter, kLocal, kGlobal }; - /// Resolves the program, without creating final the semantic nodes. /// @returns true on success, false on error bool ResolveInternal();
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index e060682..1250deb 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc
@@ -23,6 +23,7 @@ #include "src/ast/break_statement.h" #include "src/ast/call_statement.h" #include "src/ast/continue_statement.h" +#include "src/ast/float_literal.h" #include "src/ast/if_statement.h" #include "src/ast/intrinsic_texture_helper_test.h" #include "src/ast/loop_statement.h" @@ -903,6 +904,33 @@ EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>()); } +TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) { + auto* func = Func("my_func", ast::VariableList{}, ty.void_(), + { + Decl(Const("var", ty.f32(), Construct(ty.f32()))), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); + EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>()); +} + +TEST_F(ResolverTest, Function_NotRegisterFunctionParams) { + auto* func = Func("my_func", {Const("var", ty.f32(), Construct(ty.f32()))}, + ty.void_(), {}); + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); + EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>()); +} + TEST_F(ResolverTest, Function_ReturnStatements) { auto* var = Var("foo", ty.f32());