tint: Refactor Extensions / Enables. * Extract ast::Enable::ExtensionKind to ast::Extension. * Move the parsing out of ast::Enable and next to ast/extension.h * Change the ast::Enable constructor to take the Extension, instead of a std::string. It's the WGSL parser's responsibility to parse, not the AST nodes. * Add ProgramBuilder::Enable() helper. * Keep ast::Module simple - keep track of the declared AST Enable nodes, don't do any deduplicating of the enabled extensions. * Add the de-duplicated ast::Extensions to the sem::Module. * Remove the kInternalExtensionForTesting enum value - we have kF16 now, which can be used instead for testing. * Rename kNoExtension to kNone. Bug: tint:1472 Change-Id: I9af635e95d36991ea468e6e0bf6798bb50937edc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90523 Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 57069ed..d6add45 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -663,7 +663,7 @@ // Test that WGSL extension used by enable directives must be allowed by WebGPU. TEST_F(ShaderModuleValidationTest, ExtensionMustBeAllowed) { ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( -enable InternalExtensionForTesting; +enable f16; @stage(compute) @workgroup_size(1) fn main() {})")); }
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 0cacca4..2f4fd62 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn
@@ -226,6 +226,8 @@ "ast/enable.h", "ast/expression.cc", "ast/expression.h", + "ast/extension.cc", + "ast/extension.h", "ast/external_texture.cc", "ast/external_texture.h", "ast/f16.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index f4c3884..2c03e81 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt
@@ -114,6 +114,8 @@ ast/enable.h ast/expression.cc ast/expression.h + ast/extension.cc + ast/extension.h ast/external_texture.cc ast/external_texture.h ast/f16.cc @@ -691,6 +693,7 @@ ast/depth_texture_test.cc ast/discard_statement_test.cc ast/enable_test.cc + ast/extension_test.cc ast/external_texture_test.cc ast/f16_test.cc ast/f32_test.cc
diff --git a/src/tint/ast/enable.cc b/src/tint/ast/enable.cc index 200c2be..ef43200 100644 --- a/src/tint/ast/enable.cc +++ b/src/tint/ast/enable.cc
@@ -21,47 +21,7 @@ namespace tint::ast { -Enable::ExtensionKind Enable::NameToKind(const std::string& name) { - if (name == "chromium_experimental_dp4a") { - return Enable::ExtensionKind::kChromiumExperimentalDP4a; - } - if (name == "chromium_disable_uniformity_analysis") { - return Enable::ExtensionKind::kChromiumDisableUniformityAnalysis; - } - if (name == "f16") { - return Enable::ExtensionKind::kF16; - } - - // The reserved internal extension name for testing - if (name == "InternalExtensionForTesting") { - return Enable::ExtensionKind::kInternalExtensionForTesting; - } - - return Enable::ExtensionKind::kNoExtension; -} - -std::string Enable::KindToName(ExtensionKind kind) { - switch (kind) { - case ExtensionKind::kChromiumExperimentalDP4a: - return "chromium_experimental_dp4a"; - case ExtensionKind::kChromiumDisableUniformityAnalysis: - return "chromium_disable_uniformity_analysis"; - case ExtensionKind::kF16: - return "f16"; - // The reserved internal extension for testing - case ExtensionKind::kInternalExtensionForTesting: - return "InternalExtensionForTesting"; - case ExtensionKind::kNoExtension: - // Return an empty string for kNoExtension - return {}; - // No default case, as this switch must cover all ExtensionKind values. - } - // This return shall never get hit. - return {}; -} - -Enable::Enable(ProgramID pid, const Source& src, const std::string& ext_name) - : Base(pid, src), name(ext_name), kind(NameToKind(ext_name)) {} +Enable::Enable(ProgramID pid, const Source& src, Extension ext) : Base(pid, src), extension(ext) {} Enable::Enable(Enable&&) = default; @@ -69,6 +29,6 @@ const Enable* Enable::Clone(CloneContext* ctx) const { auto src = ctx->Clone(source); - return ctx->dst->create<Enable>(src, name); + return ctx->dst->create<Enable>(src, extension); } } // namespace tint::ast
diff --git a/src/tint/ast/enable.h b/src/tint/ast/enable.h index f190b0a..674d9cb 100644 --- a/src/tint/ast/enable.h +++ b/src/tint/ast/enable.h
@@ -16,63 +16,26 @@ #define SRC_TINT_AST_ENABLE_H_ #include <string> -#include <unordered_set> #include <utility> +#include <vector> -#include "src/tint/ast/access.h" -#include "src/tint/ast/expression.h" +#include "src/tint/ast/extension.h" +#include "src/tint/ast/node.h" namespace tint::ast { -/// An instance of this class represents one extension mentioned in a -/// "enable" derictive. Example: -/// // Enable an extension named "f16" -/// enable f16; -class Enable : public Castable<Enable, Node> { +/// An "enable" directive. Example: +/// ``` +/// // Enable an extension named "f16" +/// enable f16; +/// ``` +class Enable final : public Castable<Enable, Node> { public: - /// The enum class identifing each supported WGSL extension - enum class ExtensionKind { - /// An internal reserved extension for test, named - /// "InternalExtensionForTesting". - kInternalExtensionForTesting, - /// WGSL Extension "f16" - kF16, - - /// An extension for the experimental feature - /// "chromium_experimental_dp4a". - /// See crbug.com/tint/1497 for more details - kChromiumExperimentalDP4a, - /// A Chromium-specific extension for disabling uniformity analysis. - kChromiumDisableUniformityAnalysis, - - /// Reserved for representing "No extension required" or "Not a valid extension". - kNoExtension, - }; - - /// Convert a string of extension name into one of ExtensionKind enum value, - /// the result will be ExtensionKind::kNoExtension if the name is not a - /// known extension name. A extension node of kind kNoExtension must not - /// exist in the AST tree, and using a unknown extension name in WGSL code - /// should result in a shader-creation error. - /// @param name string of the extension name - /// @return the ExtensionKind enum value for the extension of given name, or - /// kNoExtension if no known extension has the given name - static ExtensionKind NameToKind(const std::string& name); - - /// Convert the ExtensionKind enum value to corresponding extension name - /// string. 
If the given enum value is kNoExtension or don't have a known - /// name, return an empty string instead. - /// @param kind the ExtensionKind enum value - /// @return string of the extension name corresponding to the given kind, or - /// an empty string if the given enum value is kNoExtension or don't have a - /// known corresponding name - static std::string KindToName(ExtensionKind kind); - /// Create a extension /// @param pid the identifier of the program that owns this node /// @param src the source of this node - /// @param name the name of extension - Enable(ProgramID pid, const Source& src, const std::string& name); + /// @param ext the extension + Enable(ProgramID pid, const Source& src, Extension ext); /// Move constructor Enable(Enable&&); @@ -85,14 +48,11 @@ const Enable* Clone(CloneContext* ctx) const override; /// The extension name - const std::string name; - - /// The extension kind - const ExtensionKind kind; + const Extension extension; }; -/// A set of extension kinds -using ExtensionSet = std::unordered_set<Enable::ExtensionKind>; +/// A list of enables +using EnableList = std::vector<const Enable*>; } // namespace tint::ast
diff --git a/src/tint/ast/enable_test.cc b/src/tint/ast/enable_test.cc index 208c85d..e8b6e5c 100644 --- a/src/tint/ast/enable_test.cc +++ b/src/tint/ast/enable_test.cc
@@ -19,40 +19,15 @@ namespace tint::ast { namespace { -using AstExtensionTest = TestHelper; +using EnableTest = TestHelper; -TEST_F(AstExtensionTest, Creation) { - auto* ext = - create<Enable>(Source{Source::Range{Source::Location{20, 2}, Source::Location{20, 5}}}, - "InternalExtensionForTesting"); +TEST_F(EnableTest, Creation) { + auto* ext = create<ast::Enable>(Source{{{20, 2}, {20, 5}}}, Extension::kF16); EXPECT_EQ(ext->source.range.begin.line, 20u); EXPECT_EQ(ext->source.range.begin.column, 2u); EXPECT_EQ(ext->source.range.end.line, 20u); EXPECT_EQ(ext->source.range.end.column, 5u); - EXPECT_EQ(ext->kind, ast::Enable::ExtensionKind::kInternalExtensionForTesting); -} - -TEST_F(AstExtensionTest, Creation_InvalidName) { - auto* ext = create<Enable>( - Source{Source::Range{Source::Location{20, 2}, Source::Location{20, 5}}}, std::string()); - EXPECT_EQ(ext->source.range.begin.line, 20u); - EXPECT_EQ(ext->source.range.begin.column, 2u); - EXPECT_EQ(ext->source.range.end.line, 20u); - EXPECT_EQ(ext->source.range.end.column, 5u); - EXPECT_EQ(ext->kind, ast::Enable::ExtensionKind::kNoExtension); -} - -TEST_F(AstExtensionTest, NameToKind_InvalidName) { - EXPECT_EQ(ast::Enable::NameToKind(std::string()), ast::Enable::ExtensionKind::kNoExtension); - EXPECT_EQ(ast::Enable::NameToKind("__ImpossibleExtensionName"), - ast::Enable::ExtensionKind::kNoExtension); - EXPECT_EQ(ast::Enable::NameToKind("123"), ast::Enable::ExtensionKind::kNoExtension); -} - -TEST_F(AstExtensionTest, KindToName) { - EXPECT_EQ(ast::Enable::KindToName(ast::Enable::ExtensionKind::kInternalExtensionForTesting), - "InternalExtensionForTesting"); - EXPECT_EQ(ast::Enable::KindToName(ast::Enable::ExtensionKind::kNoExtension), std::string()); + EXPECT_EQ(ext->extension, Extension::kF16); } } // namespace
diff --git a/src/tint/ast/extension.cc b/src/tint/ast/extension.cc new file mode 100644 index 0000000..f03e3a0 --- /dev/null +++ b/src/tint/ast/extension.cc
@@ -0,0 +1,51 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/extension.h" + +namespace tint::ast { + +Extension ParseExtension(const std::string& name) { + if (name == "chromium_experimental_dp4a") { + return Extension::kChromiumExperimentalDP4a; + } + if (name == "chromium_disable_uniformity_analysis") { + return Extension::kChromiumDisableUniformityAnalysis; + } + if (name == "f16") { + return Extension::kF16; + } + return Extension::kNone; +} + +const char* str(Extension ext) { + switch (ext) { + case Extension::kChromiumExperimentalDP4a: + return "chromium_experimental_dp4a"; + case Extension::kChromiumDisableUniformityAnalysis: + return "chromium_disable_uniformity_analysis"; + case Extension::kF16: + return "f16"; + case Extension::kNone: + return "<none>"; + } + return "<unknown>"; +} + +std::ostream& operator<<(std::ostream& out, Extension i) { + out << str(i); + return out; +} + +} // namespace tint::ast
diff --git a/src/tint/ast/extension.h b/src/tint/ast/extension.h new file mode 100644 index 0000000..21e9ac1 --- /dev/null +++ b/src/tint/ast/extension.h
@@ -0,0 +1,68 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_EXTENSION_H_ +#define SRC_TINT_AST_EXTENSION_H_ + +#include <sstream> +#include <string> + +#include "src/tint/utils/unique_vector.h" + +namespace tint::ast { + +/// An enumerator of WGSL extensions +enum class Extension { + /// WGSL Extension "f16" + kF16, + + /// An extension for the experimental feature + /// "chromium_experimental_dp4a". + /// See crbug.com/tint/1497 for more details + kChromiumExperimentalDP4a, + /// A Chromium-specific extension for disabling uniformity analysis. + kChromiumDisableUniformityAnalysis, + + /// Reserved for representing "No extension required" or "Not a valid extension". + kNone, +}; + +/// Convert an extension name string into an Extension enum value. The result will be +/// Extension::kNone if the name is not a known extension name. An extension node of kind +/// kNone must not exist in the AST tree, and using an unknown extension name in WGSL code +/// should result in a shader-creation error. +/// @param name string of the extension name +/// @return the Extension enum value for the extension of given name, or kNone if no known extension +/// has the given name +Extension ParseExtension(const std::string& name); + +/// Convert the Extension enum value to corresponding extension name string. 
+/// @param ext the Extension enum value +/// @return string of the extension name corresponding to the given kind, or +/// an empty string if the given enum value is kNone or does not have a +/// known corresponding name +const char* ExtensionName(Extension ext); + +/// @returns the name of the extension. +const char* str(Extension i); + +/// Emits the name of the extension type. +std::ostream& operator<<(std::ostream& out, Extension i); + +// A unique vector of extensions +using Extensions = utils::UniqueVector<Extension>; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_EXTENSION_H_
diff --git a/src/tint/ast/extension_test.cc b/src/tint/ast/extension_test.cc new file mode 100644 index 0000000..ed27674b --- /dev/null +++ b/src/tint/ast/extension_test.cc
@@ -0,0 +1,36 @@ + +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/extension.h" + +#include "gtest/gtest.h" + +namespace tint::ast { +namespace { + +TEST(ExtensionTest, NameToKind_InvalidName) { + EXPECT_EQ(ParseExtension("f16"), Extension::kF16); + EXPECT_EQ(ParseExtension(""), Extension::kNone); + EXPECT_EQ(ParseExtension("__ImpossibleExtensionName"), Extension::kNone); + EXPECT_EQ(ParseExtension("123"), Extension::kNone); +} + +TEST(ExtensionTest, KindToName) { + EXPECT_EQ(std::string(str(Extension::kF16)), "f16"); + EXPECT_EQ(std::string(str(Extension::kNone)), "<none>"); +} + +} // namespace +} // namespace tint::ast
diff --git a/src/tint/ast/module.cc b/src/tint/ast/module.cc index e163c19..40dff98 100644 --- a/src/tint/ast/module.cc +++ b/src/tint/ast/module.cc
@@ -68,18 +68,18 @@ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id); global_variables_.push_back(var); }, - [&](const Enable* ext) { - TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, ext, program_id); - extensions_.insert(ext->kind); + [&](const Enable* enable) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, enable, program_id); + enables_.push_back(enable); }, [&](Default) { TINT_ICE(AST, diags) << "Unknown global declaration type"; }); } -void Module::AddEnable(const ast::Enable* ext) { - TINT_ASSERT(AST, ext); - TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, ext, program_id); - global_declarations_.push_back(ext); - extensions_.insert(ext->kind); +void Module::AddEnable(const ast::Enable* enable) { + TINT_ASSERT(AST, enable); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, enable, program_id); + global_declarations_.push_back(enable); + enables_.push_back(enable); } void Module::AddGlobalVariable(const ast::Variable* var) { @@ -117,7 +117,7 @@ type_decls_.clear(); functions_.clear(); global_variables_.clear(); - extensions_.clear(); + enables_.clear(); for (auto* decl : global_declarations_) { if (!decl) {
diff --git a/src/tint/ast/module.h b/src/tint/ast/module.h index d8be2ed..45b1ec6 100644 --- a/src/tint/ast/module.h +++ b/src/tint/ast/module.h
@@ -78,7 +78,7 @@ VariableList& GlobalVariables() { return global_variables_; } - /// @returns the extension set for the module - const ExtensionSet& Extensions() const { return extensions_; } + /// @returns the list of enable directives declared in the module + const EnableList& Enables() const { return enables_; } /// Adds a type declaration to the Builder. /// @param decl the type declaration to add @@ -120,7 +120,7 @@ std::vector<const TypeDecl*> type_decls_; FunctionList functions_; VariableList global_variables_; - ExtensionSet extensions_; + EnableList enables_; }; } // namespace tint::ast
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 2bfe815..9a2afb8 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc
@@ -19,6 +19,7 @@ #include "src/tint/ast/bool_literal_expression.h" #include "src/tint/ast/call_expression.h" +#include "src/tint/ast/extension.h" #include "src/tint/ast/float_literal_expression.h" #include "src/tint/ast/id_attribute.h" #include "src/tint/ast/interpolate_attribute.h" @@ -32,6 +33,7 @@ #include "src/tint/sem/function.h" #include "src/tint/sem/i32.h" #include "src/tint/sem/matrix.h" +#include "src/tint/sem/module.h" #include "src/tint/sem/multisampled_texture.h" #include "src/tint/sem/sampled_texture.h" #include "src/tint/sem/statement.h" @@ -544,16 +546,13 @@ } std::vector<std::string> Inspector::GetUsedExtensionNames() { - std::vector<std::string> result; - - ast::ExtensionSet set = program_->AST().Extensions(); - result.reserve(set.size()); - for (auto kind : set) { - std::string name = ast::Enable::KindToName(kind); - result.push_back(name); + auto& extensions = program_->Sem().Module()->Extensions(); + std::vector<std::string> out; + out.reserve(extensions.size()); + for (auto ext : extensions) { + out.push_back(ast::str(ext)); } - - return result; + return out; } std::vector<std::pair<std::string, Source>> Inspector::GetEnableDirectives() { @@ -563,7 +562,7 @@ auto global_decls = program_->AST().GlobalDeclarations(); for (auto* node : global_decls) { if (auto* ext = node->As<ast::Enable>()) { - result.push_back({ext->name, ext->source}); + result.push_back({ast::str(ext->extension), ext->source}); } }
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc index 18a33b1..033f5b8 100644 --- a/src/tint/inspector/inspector_test.cc +++ b/src/tint/inspector/inspector_test.cc
@@ -2849,7 +2849,7 @@ // Test calling GetUsedExtensionNames on a shader with valid extension. TEST_F(InspectorGetUsedExtensionNamesTest, Simple) { std::string shader = R"( -enable InternalExtensionForTesting; +enable f16; @stage(fragment) fn main() { @@ -2859,15 +2859,15 @@ auto result = inspector.GetUsedExtensionNames(); EXPECT_EQ(result.size(), 1u); - EXPECT_EQ(result[0], "InternalExtensionForTesting"); + EXPECT_EQ(result[0], "f16"); } // Test calling GetUsedExtensionNames on a shader with a extension enabled for // multiple times. TEST_F(InspectorGetUsedExtensionNamesTest, Duplicated) { std::string shader = R"( -enable InternalExtensionForTesting; -enable InternalExtensionForTesting; +enable f16; +enable f16; @stage(fragment) fn main() { @@ -2877,7 +2877,7 @@ auto result = inspector.GetUsedExtensionNames(); EXPECT_EQ(result.size(), 1u); - EXPECT_EQ(result[0], "InternalExtensionForTesting"); + EXPECT_EQ(result[0], "f16"); } // Test calling GetEnableDirectives on a empty shader. @@ -2906,7 +2906,7 @@ // Test calling GetEnableDirectives on a shader with valid extension. TEST_F(InspectorGetEnableDirectivesTest, Simple) { std::string shader = R"( -enable InternalExtensionForTesting; +enable f16; @stage(fragment) fn main() { @@ -2916,17 +2916,17 @@ auto result = inspector.GetEnableDirectives(); EXPECT_EQ(result.size(), 1u); - EXPECT_EQ(result[0].first, "InternalExtensionForTesting"); - EXPECT_EQ(result[0].second.range, (Source::Range{{2, 8}, {2, 35}})); + EXPECT_EQ(result[0].first, "f16"); + EXPECT_EQ(result[0].second.range, (Source::Range{{2, 8}, {2, 11}})); } // Test calling GetEnableDirectives on a shader with a extension enabled for // multiple times. 
TEST_F(InspectorGetEnableDirectivesTest, Duplicated) { std::string shader = R"( -enable InternalExtensionForTesting; +enable f16; -enable InternalExtensionForTesting; +enable f16; @stage(fragment) fn main() { })"; @@ -2935,10 +2935,10 @@ auto result = inspector.GetEnableDirectives(); EXPECT_EQ(result.size(), 2u); - EXPECT_EQ(result[0].first, "InternalExtensionForTesting"); - EXPECT_EQ(result[0].second.range, (Source::Range{{2, 8}, {2, 35}})); - EXPECT_EQ(result[1].first, "InternalExtensionForTesting"); - EXPECT_EQ(result[1].second.range, (Source::Range{{4, 8}, {4, 35}})); + EXPECT_EQ(result[0].first, "f16"); + EXPECT_EQ(result[0].second.range, (Source::Range{{2, 8}, {2, 11}})); + EXPECT_EQ(result[1].first, "f16"); + EXPECT_EQ(result[1].second.range, (Source::Range{{4, 8}, {4, 11}})); } // Crash was occuring in ::GenerateSamplerTargets, when
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index a91d87c..32c0e93 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h
@@ -39,6 +39,7 @@ #include "src/tint/ast/disable_validation_attribute.h" #include "src/tint/ast/discard_statement.h" #include "src/tint/ast/enable.h" +#include "src/tint/ast/extension.h" #include "src/tint/ast/external_texture.h" #include "src/tint/ast/f16.h" #include "src/tint/ast/f32.h" @@ -1307,6 +1308,15 @@ return Construct(ty.array(subtype, std::forward<EXPR>(n)), std::forward<ARGS>(args)...); } + /// Adds the extension to the list of enable directives at the top of the module. + /// @param ext the extension to enable + /// @return an `ast::Enable` enabling the given extension. + const ast::Enable* Enable(ast::Extension ext) { + auto* enable = create<ast::Enable>(ext); + AST().AddEnable(enable); + return enable; + } + /// @param name the variable name /// @param type the variable type /// @param optional the optional variable settings.
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index abd97e5..4f0b152 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -366,13 +366,11 @@ return Failure::kErrored; } - if (ast::Enable::NameToKind(name.value) != ast::Enable::ExtensionKind::kNoExtension) { - const ast::Enable* extension = create<ast::Enable>(name.source, name.value); - builder_.AST().AddEnable(extension); - } else { - // Error if an unknown extension is used + auto extension = ast::ParseExtension(name.value); + if (extension == ast::Extension::kNone) { return add_error(name.source, "unsupported extension: '" + name.value + "'"); } + builder_.AST().AddEnable(create<ast::Enable>(name.source, extension)); return true; });
diff --git a/src/tint/reader/wgsl/parser_impl_enable_directive_test.cc b/src/tint/reader/wgsl/parser_impl_enable_directive_test.cc index 0fd6b80..bdf7eeb 100644 --- a/src/tint/reader/wgsl/parser_impl_enable_directive_test.cc +++ b/src/tint/reader/wgsl/parser_impl_enable_directive_test.cc
@@ -23,41 +23,36 @@ // Test a valid enable directive. TEST_F(EnableDirectiveTest, Valid) { - auto p = parser("enable InternalExtensionForTesting;"); + auto p = parser("enable f16;"); p->enable_directive(); EXPECT_FALSE(p->has_error()) << p->error(); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions(), - ast::ExtensionSet{ast::Enable::ExtensionKind::kInternalExtensionForTesting}); - EXPECT_EQ(ast.GlobalDeclarations().size(), 1u); - auto* node = ast.GlobalDeclarations()[0]->As<ast::Enable>(); - EXPECT_TRUE(node != nullptr); - EXPECT_EQ(node->name, "InternalExtensionForTesting"); - EXPECT_EQ(node->kind, ast::Enable::ExtensionKind::kInternalExtensionForTesting); + ASSERT_EQ(ast.Enables().size(), 1u); + auto* enable = ast.Enables()[0]; + EXPECT_EQ(enable->extension, ast::Extension::kF16); + ASSERT_EQ(ast.GlobalDeclarations().size(), 1u); + EXPECT_EQ(ast.GlobalDeclarations()[0], enable); } // Test multiple enable directives for a same extension. TEST_F(EnableDirectiveTest, EnableMultipleTime) { auto p = parser(R"( -enable InternalExtensionForTesting; -enable InternalExtensionForTesting; +enable f16; +enable f16; )"); p->translation_unit(); EXPECT_FALSE(p->has_error()) << p->error(); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions(), - ast::ExtensionSet{ast::Enable::ExtensionKind::kInternalExtensionForTesting}); - EXPECT_EQ(ast.GlobalDeclarations().size(), 2u); - auto* node1 = ast.GlobalDeclarations()[0]->As<ast::Enable>(); - EXPECT_TRUE(node1 != nullptr); - EXPECT_EQ(node1->name, "InternalExtensionForTesting"); - EXPECT_EQ(node1->kind, ast::Enable::ExtensionKind::kInternalExtensionForTesting); - auto* node2 = ast.GlobalDeclarations()[1]->As<ast::Enable>(); - EXPECT_TRUE(node2 != nullptr); - EXPECT_EQ(node2->name, "InternalExtensionForTesting"); - EXPECT_EQ(node2->kind, ast::Enable::ExtensionKind::kInternalExtensionForTesting); + ASSERT_EQ(ast.Enables().size(), 2u); + auto* enable_a = ast.Enables()[0]; + 
auto* enable_b = ast.Enables()[1]; + EXPECT_EQ(enable_a->extension, ast::Extension::kF16); + EXPECT_EQ(enable_b->extension, ast::Extension::kF16); + ASSERT_EQ(ast.GlobalDeclarations().size(), 2u); + EXPECT_EQ(ast.GlobalDeclarations()[0], enable_a); + EXPECT_EQ(ast.GlobalDeclarations()[1], enable_b); } // Test an unknown extension identifier. @@ -69,42 +64,42 @@ EXPECT_EQ(p->error(), "1:8: unsupported extension: 'NotAValidExtensionName'"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } -// Test an enable directive missing ending semiclon. -TEST_F(EnableDirectiveTest, MissingEndingSemiclon) { - auto p = parser("enable InternalExtensionForTesting"); +// Test an enable directive missing ending semicolon. +TEST_F(EnableDirectiveTest, MissingEndingSemicolon) { + auto p = parser("enable f16"); p->translation_unit(); EXPECT_TRUE(p->has_error()); - EXPECT_EQ(p->error(), "1:35: expected ';' for enable directive"); + EXPECT_EQ(p->error(), "1:11: expected ';' for enable directive"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } // Test using invalid tokens in an enable directive. 
TEST_F(EnableDirectiveTest, InvalidTokens) { { - auto p = parser("enable InternalExtensionForTesting<;"); + auto p = parser("enable f16<;"); p->translation_unit(); EXPECT_TRUE(p->has_error()); - EXPECT_EQ(p->error(), "1:35: expected ';' for enable directive"); + EXPECT_EQ(p->error(), "1:11: expected ';' for enable directive"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } { - auto p = parser("enable <InternalExtensionForTesting;"); + auto p = parser("enable <f16;"); p->translation_unit(); EXPECT_TRUE(p->has_error()); EXPECT_EQ(p->error(), "1:8: invalid extension name"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } { @@ -114,7 +109,7 @@ EXPECT_EQ(p->error(), "1:8: invalid extension name"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } { @@ -124,7 +119,7 @@ EXPECT_EQ(p->error(), "1:8: invalid extension name"); auto program = p->program(); auto& ast = program.AST(); - EXPECT_EQ(ast.Extensions().size(), 0u); + EXPECT_EQ(ast.Enables().size(), 0u); EXPECT_EQ(ast.GlobalDeclarations().size(), 0u); } } @@ -133,35 +128,39 @@ TEST_F(EnableDirectiveTest, FollowingOtherGlobalDecl) { auto p = parser(R"( var<private> t: f32 = 0f; -enable InternalExtensionForTesting; +enable f16; )"); p->translation_unit(); EXPECT_TRUE(p->has_error()); EXPECT_EQ(p->error(), "3:1: enable directives must come before all global declarations"); auto program = p->program(); auto& ast = program.AST(); - // Accept the enable directive although it cause an error - EXPECT_EQ(ast.Extensions(), - ast::ExtensionSet{ast::Enable::ExtensionKind::kInternalExtensionForTesting}); - 
EXPECT_EQ(ast.GlobalDeclarations().size(), 2u); + // Accept the enable directive although it caused an error + ASSERT_EQ(ast.Enables().size(), 1u); + auto* enable = ast.Enables()[0]; + EXPECT_EQ(enable->extension, ast::Extension::kF16); + ASSERT_EQ(ast.GlobalDeclarations().size(), 2u); + EXPECT_EQ(ast.GlobalDeclarations()[1], enable); } -// Test an enable directive go after an empty semiclon. -TEST_F(EnableDirectiveTest, FollowingEmptySemiclon) { +// Test an enable directive go after an empty semicolon. +TEST_F(EnableDirectiveTest, FollowingEmptySemicolon) { auto p = parser(R"( ; -enable InternalExtensionForTesting; +enable f16; )"); p->translation_unit(); - // An empty semiclon is treated as a global declaration + // An empty semicolon is treated as a global declaration EXPECT_TRUE(p->has_error()); EXPECT_EQ(p->error(), "3:1: enable directives must come before all global declarations"); auto program = p->program(); auto& ast = program.AST(); // Accept the enable directive although it cause an error - EXPECT_EQ(ast.Extensions(), - ast::ExtensionSet{ast::Enable::ExtensionKind::kInternalExtensionForTesting}); - EXPECT_EQ(ast.GlobalDeclarations().size(), 1u); + ASSERT_EQ(ast.Enables().size(), 1u); + auto* enable = ast.Enables()[0]; + EXPECT_EQ(enable->extension, ast::Extension::kF16); + ASSERT_EQ(ast.GlobalDeclarations().size(), 1u); + EXPECT_EQ(ast.GlobalDeclarations()[0], enable); } } // namespace
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc index 0e48b55..770d8d0 100644 --- a/src/tint/resolver/builtin_validation_test.cc +++ b/src/tint/resolver/builtin_validation_test.cc
@@ -378,10 +378,7 @@ TEST_F(ResolverDP4aExtensionValidationTest, Dot4I8PackedWithExtension) { // enable chromium_experimental_dp4a; // fn func { return dot4I8Packed(1u, 2u); } - auto* ext = - create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, - "chromium_experimental_dp4a"); - AST().AddEnable(ext); + Enable(ast::Extension::kChromiumExperimentalDP4a); Func("func", {}, ty.i32(), { @@ -409,10 +406,7 @@ TEST_F(ResolverDP4aExtensionValidationTest, Dot4U8PackedWithExtension) { // enable chromium_experimental_dp4a; // fn func { return dot4U8Packed(1u, 2u); } - auto* ext = - create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, - "chromium_experimental_dp4a"); - AST().AddEnable(ext); + Enable(ast::Extension::kChromiumExperimentalDP4a); Func("func", {}, ty.u32(), {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 4993cfe..b9c0833 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc
@@ -101,9 +101,6 @@ return false; } - // Create the semantic module - builder_->Sem().SetModule(builder_->create<sem::Module>(dependencies_.ordered_globals)); - bool result = ResolveInternal(); if (!result && !diagnostics_.contains_errors()) { @@ -111,6 +108,10 @@ return false; } + // Create the semantic module + builder_->Sem().SetModule(builder_->create<sem::Module>( + std::move(dependencies_.ordered_globals), std::move(enabled_extensions_))); + return result; } @@ -120,19 +121,16 @@ // Process all module-scope declarations in dependency order. for (auto* decl : dependencies_.ordered_globals) { Mark(decl); - // Enable directives don't have sem node. - if (decl->Is<ast::Enable>()) { - continue; - } - if (!Switch( + if (!Switch<bool>( decl, // + [&](const ast::Enable* e) { return Enable(e); }, [&](const ast::TypeDecl* td) { return TypeDecl(td); }, [&](const ast::Function* func) { return Function(func); }, [&](const ast::Variable* var) { return GlobalVariable(var); }, [&](Default) { TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled global declaration: " << decl->TypeInfo().name; - return nullptr; + return false; })) { return false; } @@ -146,8 +144,10 @@ return false; } - if (!AnalyzeUniformity(builder_, dependencies_)) { - // TODO(jrprice): Reject programs that fail uniformity analysis. + if (!enabled_extensions_.contains(ast::Extension::kChromiumDisableUniformityAnalysis)) { + if (!AnalyzeUniformity(builder_, dependencies_)) { + // TODO(jrprice): Reject programs that fail uniformity analysis. + } } bool result = true; @@ -174,7 +174,7 @@ [&](const ast::U32*) { return builder_->create<sem::U32>(); }, [&](const ast::F16* t) -> sem::F16* { // Validate if f16 type is allowed. 
- if (builder_->AST().Extensions().count(ast::Enable::ExtensionKind::kF16) == 0) { + if (!enabled_extensions_.contains(ast::Extension::kF16)) { AddError("f16 used without 'f16' extension enabled", t->source); return nullptr; } @@ -1358,7 +1358,7 @@ current_function_->AddDirectlyCalledBuiltin(builtin); - if (!validator_.RequiredExtensionForBuiltinFunction(call, builder_->AST().Extensions())) { + if (!validator_.RequiredExtensionForBuiltinFunction(call, enabled_extensions_)) { return nullptr; } @@ -1750,6 +1750,11 @@ return sem; } +bool Resolver::Enable(const ast::Enable* enable) { + enabled_extensions_.add(enable->extension); + return true; +} + sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { sem::Type* result = nullptr; if (auto* alias = named_type->As<ast::Alias>()) {
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 348c8e7..865c243 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h
@@ -228,6 +228,10 @@ /// @param ty the ast::Type sem::Type* Type(const ast::Type* ty); + /// @param enable the enable declaration + /// @returns true on success, false otherwise + bool Enable(const ast::Enable* enable); + /// @param named_type the named type to resolve /// @returns the resolved semantic type sem::Type* TypeDecl(const ast::TypeDecl* named_type); @@ -351,6 +355,7 @@ DependencyGraph dependencies_; SemHelper sem_; Validator validator_; + ast::Extensions enabled_extensions_; std::vector<sem::Function*> entry_points_; std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_; std::unordered_set<const ast::Node*> marked_;
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc index 5f4617e..26b705f 100644 --- a/src/tint/resolver/type_validation_test.cc +++ b/src/tint/resolver/type_validation_test.cc
@@ -665,8 +665,8 @@ TEST_F(ResolverTypeValidationTest, F16TypeUsedWithExtension) { // enable f16; // var<private> v : f16; - auto* ext = create<ast::Enable>("f16"); - AST().AddEnable(ext); + Enable(ast::Extension::kF16); + Global("v", ty.f16(), ast::StorageClass::kPrivate); EXPECT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index 350e93b..273a07e 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc
@@ -1548,11 +1548,6 @@ } // namespace bool AnalyzeUniformity(ProgramBuilder* builder, const DependencyGraph& dependency_graph) { - if (builder->AST().Extensions().count( - ast::Enable::ExtensionKind::kChromiumDisableUniformityAnalysis)) { - return true; - } - UniformityGraph graph(builder); return graph.Build(dependency_graph); }
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index b278954..9f698e6 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc
@@ -1553,21 +1553,22 @@ check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3); } -bool Validator::RequiredExtensionForBuiltinFunction(const sem::Call* call, - const ast::ExtensionSet& extensionSet) const { +bool Validator::RequiredExtensionForBuiltinFunction( + const sem::Call* call, + const ast::Extensions& enabled_extensions) const { const auto* builtin = call->Target()->As<sem::Builtin>(); if (!builtin) { return true; } const auto extension = builtin->RequiredExtension(); - if (extension == ast::Enable::ExtensionKind::kNoExtension) { + if (extension == ast::Extension::kNone) { return true; } - if (extensionSet.find(extension) == extensionSet.cend()) { + if (!enabled_extensions.contains(extension)) { AddError("cannot call built-in function '" + std::string(builtin->str()) + - "' without extension " + ast::Enable::KindToName(extension), + "' without extension " + ast::str(extension), call->Declaration()->source); return false; }
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index 6efb543..a8c18d5 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h
@@ -361,10 +361,10 @@ /// Validates an optional builtin function and its required extension. /// @param call the builtin call to validate - /// @param extensionSet all the extensions declared in current module + /// @param enabled_extensions all the extensions declared in current module /// @returns true on success, false otherwise bool RequiredExtensionForBuiltinFunction(const sem::Call* call, - const ast::ExtensionSet& extensionSet) const; + const ast::Extensions& enabled_extensions) const; /// Validates there are no duplicate attributes /// @param attributes the list of attributes to validate
diff --git a/src/tint/sem/builtin.cc b/src/tint/sem/builtin.cc index faf451e..bb2878b 100644 --- a/src/tint/sem/builtin.cc +++ b/src/tint/sem/builtin.cc
@@ -153,11 +153,11 @@ return false; } -ast::Enable::ExtensionKind Builtin::RequiredExtension() const { +ast::Extension Builtin::RequiredExtension() const { if (IsDP4a()) { - return ast::Enable::ExtensionKind::kChromiumExperimentalDP4a; + return ast::Extension::kChromiumExperimentalDP4a; } - return ast::Enable::ExtensionKind::kNoExtension; + return ast::Extension::kNone; } } // namespace tint::sem
diff --git a/src/tint/sem/builtin.h b/src/tint/sem/builtin.h index 4752f16..1dc61ad 100644 --- a/src/tint/sem/builtin.h +++ b/src/tint/sem/builtin.h
@@ -18,6 +18,7 @@ #include <string> #include <vector> +#include "src/tint/ast/extension.h" #include "src/tint/sem/builtin_type.h" #include "src/tint/sem/call_target.h" #include "src/tint/sem/pipeline_stage_set.h" @@ -144,8 +145,8 @@ bool HasSideEffects() const; /// @returns the required extension of this builtin function. Returns - /// ast::Enable::ExtensionKind::kNoExtension if no extension is required. - ast::Enable::ExtensionKind RequiredExtension() const; + /// ast::Extension::kNone if no extension is required. + ast::Extension RequiredExtension() const; private: const BuiltinType type_;
diff --git a/src/tint/sem/module.cc b/src/tint/sem/module.cc index 83b7136..7c60650 100644 --- a/src/tint/sem/module.cc +++ b/src/tint/sem/module.cc
@@ -21,8 +21,8 @@ namespace tint::sem { -Module::Module(std::vector<const ast::Node*> dep_ordered_decls) - : dep_ordered_decls_(std::move(dep_ordered_decls)) {} +Module::Module(std::vector<const ast::Node*> dep_ordered_decls, ast::Extensions extensions) + : dep_ordered_decls_(std::move(dep_ordered_decls)), extensions_(std::move(extensions)) {} Module::~Module() = default;
diff --git a/src/tint/sem/module.h b/src/tint/sem/module.h index c265d4e..a7b3d45 100644 --- a/src/tint/sem/module.h +++ b/src/tint/sem/module.h
@@ -17,6 +17,7 @@ #include <vector> +#include "src/tint/ast/extension.h" #include "src/tint/sem/node.h" // Forward declarations @@ -33,7 +34,8 @@ public: /// Constructor /// @param dep_ordered_decls the dependency-ordered module-scope declarations - explicit Module(std::vector<const ast::Node*> dep_ordered_decls); + /// @param extensions the list of enabled extensions in the module + Module(std::vector<const ast::Node*> dep_ordered_decls, ast::Extensions extensions); /// Destructor ~Module() override; @@ -43,8 +45,12 @@ return dep_ordered_decls_; } + /// @returns the list of enabled extensions in the module + const ast::Extensions& Extensions() const { return extensions_; } + private: const std::vector<const ast::Node*> dep_ordered_decls_; + ast::Extensions extensions_; }; } // namespace tint::sem
diff --git a/src/tint/transform/disable_uniformity_analysis.cc b/src/tint/transform/disable_uniformity_analysis.cc index c025031..7a30023 100644 --- a/src/tint/transform/disable_uniformity_analysis.cc +++ b/src/tint/transform/disable_uniformity_analysis.cc
@@ -17,6 +17,7 @@ #include <utility> #include "src/tint/program_builder.h" +#include "src/tint/sem/module.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::DisableUniformityAnalysis); @@ -27,13 +28,12 @@ DisableUniformityAnalysis::~DisableUniformityAnalysis() = default; bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const { - return !program->AST().Extensions().count( - ast::Enable::ExtensionKind::kChromiumDisableUniformityAnalysis); + return !program->Sem().Module()->Extensions().contains( + ast::Extension::kChromiumDisableUniformityAnalysis); } void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - ctx.dst->AST().AddEnable(ctx.dst->create<ast::Enable>( - ast::Enable::KindToName(ast::Enable::ExtensionKind::kChromiumDisableUniformityAnalysis))); + ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis); ctx.Clone(); }
diff --git a/src/tint/writer/hlsl/generator_impl_builtin_test.cc b/src/tint/writer/hlsl/generator_impl_builtin_test.cc index 6c573ff..f6e0d98 100644 --- a/src/tint/writer/hlsl/generator_impl_builtin_test.cc +++ b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
@@ -727,10 +727,7 @@ } TEST_F(HlslGeneratorImplTest_Builtin, Dot4I8Packed) { - auto* ext = - create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, - "chromium_experimental_dp4a"); - AST().AddEnable(ext); + Enable(ast::Extension::kChromiumExperimentalDP4a); auto* val1 = Var("val1", ty.u32()); auto* val2 = Var("val2", ty.u32()); @@ -756,10 +753,7 @@ } TEST_F(HlslGeneratorImplTest_Builtin, Dot4U8Packed) { - auto* ext = - create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, - "chromium_experimental_dp4a"); - AST().AddEnable(ext); + Enable(ast::Extension::kChromiumExperimentalDP4a); auto* val1 = Var("val1", ty.u32()); auto* val2 = Var("val2", ty.u32());
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index aaa64ce..0b3118a 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc
@@ -256,7 +256,7 @@ push_memory_model(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical), U32Operand(SpvMemoryModelGLSL450)}); - for (auto ext : builder_.AST().Extensions()) { + for (auto ext : builder_.Sem().Module()->Extensions()) { GenerateExtension(ext); } @@ -366,7 +366,7 @@ } } -bool Builder::GenerateExtension(ast::Enable::ExtensionKind) { +bool Builder::GenerateExtension(ast::Extension) { /* For each supported extension, push corresponding capability into the builder. For example:
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h index 34cfd76..1745ed5 100644 --- a/src/tint/writer/spirv/builder.h +++ b/src/tint/writer/spirv/builder.h
@@ -224,11 +224,11 @@ ast::InterpolationType type, ast::InterpolationSampling sampling); - /// Generates a extension for the given extension kind. Emits an error and - /// returns false if the extension kind is not supported. - /// @param kind ExtensionKind of the extension to generate + /// Generates the enabling of an extension. Emits an error and returns false if the extension is + /// not supported. + /// @param ext the extension to generate /// @returns true on success. - bool GenerateExtension(ast::Enable::ExtensionKind kind); + bool GenerateExtension(ast::Extension ext); /// Generates a label for the given id. Emits an error and returns false if /// we're currently outside a function. /// @param id the id to use for the label
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 121d904..677421d 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -62,12 +62,12 @@ bool GeneratorImpl::Generate() { // Generate enable directives before any other global declarations. - for (auto ext : program_->AST().Extensions()) { - if (!EmitEnableDirective(ext)) { + for (auto enable : program_->AST().Enables()) { + if (!EmitEnable(enable)) { return false; } } - if (!program_->AST().Extensions().empty()) { + if (!program_->AST().Enables().empty()) { line(); } // Generate global declarations in the order they appear in the module. @@ -94,13 +94,9 @@ return true; } -bool GeneratorImpl::EmitEnableDirective(const ast::Enable::ExtensionKind ext) { +bool GeneratorImpl::EmitEnable(const ast::Enable* enable) { auto out = line(); - auto extension = ast::Enable::KindToName(ext); - if (extension == "") { - return false; - } - out << "enable " << extension << ";"; + out << "enable " << enable->extension << ";"; return true; }
diff --git a/src/tint/writer/wgsl/generator_impl.h b/src/tint/writer/wgsl/generator_impl.h index 8473b4f..a17e2da 100644 --- a/src/tint/writer/wgsl/generator_impl.h +++ b/src/tint/writer/wgsl/generator_impl.h
@@ -53,9 +53,9 @@ bool Generate(); /// Handles generating a enable directive - /// @param ext the extension kind in the enable directive to generate + /// @param enable the enable node /// @returns true if the enable directive was emitted - bool EmitEnableDirective(const ast::Enable::ExtensionKind ext); + bool EmitEnable(const ast::Enable* enable); /// Handles generating a declared type /// @param ty the declared type to generate /// @returns true if the declared type was emitted
diff --git a/src/tint/writer/wgsl/generator_impl_enable_test.cc b/src/tint/writer/wgsl/generator_impl_enable_test.cc index f9de371..503a9a0 100644 --- a/src/tint/writer/wgsl/generator_impl_enable_test.cc +++ b/src/tint/writer/wgsl/generator_impl_enable_test.cc
@@ -20,10 +20,12 @@ using WgslGeneratorImplTest = TestHelper; TEST_F(WgslGeneratorImplTest, Emit_Enable) { + auto* enable = Enable(ast::Extension::kF16); + GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitEnableDirective(ast::Enable::ExtensionKind::kInternalExtensionForTesting)); - EXPECT_EQ(gen.result(), R"(enable InternalExtensionForTesting; + ASSERT_TRUE(gen.EmitEnable(enable)); + EXPECT_EQ(gen.result(), R"(enable f16; )"); }