[spirv-reader]: entry point names are WGSL identifiers
Bug: tint:233, tint:3
Change-Id: Ib753c47c4a77b852e5065c540da79d8cebe6a100
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28282
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 7771643..5b5b2f0 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -17,6 +17,7 @@
#include <cassert>
#include <cstring>
#include <limits>
+#include <locale>
#include <memory>
#include <string>
#include <unordered_map>
@@ -608,6 +609,22 @@
return true;
}
+bool ParserImpl::IsValidIdentifier(const std::string& str) {
+ if (str.empty()) {
+ return false;
+ }
+ std::locale c_locale("C");
+ if (!std::isalpha(str[0], c_locale)) {
+ return false;
+ }
+ for (const char& ch : str) {
+ if ((ch != '_') && !std::isalnum(ch, c_locale)) {
+ return false;
+ }
+ }
+ return true;
+}
+
bool ParserImpl::EmitEntryPoints() {
for (const spvtools::opt::Instruction& entry_point :
module_->entry_points()) {
@@ -616,6 +633,11 @@
const std::string ep_name = entry_point.GetOperand(2).AsString();
const std::string name = namer_.GetName(function_id);
+ if (!IsValidIdentifier(ep_name)) {
+ return Fail() << "entry point name is not a valid WGSL identifier: "
+ << ep_name;
+ }
+
ast_module_.AddEntryPoint(std::make_unique<ast::EntryPoint>(
enum_converter_.ToPipelineStage(stage), ep_name, name));
}
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index deeed0a..649bdc0 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -372,6 +372,10 @@
/// @return the Source record, or a default one
Source GetSourceForInst(const spvtools::opt::Instruction* inst) const;
+ /// @param str a candidate identifier
+ /// @returns true if the given string is a valid WGSL identifier.
+ static bool IsValidIdentifier(const std::string& str);
+
private:
/// Converts a specific SPIR-V type to a Tint type. Integer case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
diff --git a/src/reader/spirv/parser_impl_entry_point_test.cc b/src/reader/spirv/parser_impl_entry_point_test.cc
index 30dbebc..507e5f1 100644
--- a/src/reader/spirv/parser_impl_entry_point_test.cc
+++ b/src/reader/spirv/parser_impl_entry_point_test.cc
@@ -24,6 +24,7 @@
namespace spirv {
namespace {
+using ::testing::Eq;
using ::testing::HasSubstr;
std::string MakeEntryPoint(const std::string& stage,
@@ -81,13 +82,11 @@
HasSubstr(R"(EntryPoint{fragment as work = work_2})"));
}
-TEST_F(SpvParserTest, EntryPoint_NameIsSanitized) {
+TEST_F(SpvParserTest, EntryPoint_MustBeWgslIdentifier) {
auto* p = parser(test::Assemble(MakeEntryPoint("GLCompute", ".1234")));
- EXPECT_TRUE(p->BuildAndParseInternalModule());
- EXPECT_TRUE(p->error().empty());
- const auto module_str = p->module().to_str();
- EXPECT_THAT(module_str,
- HasSubstr(R"(EntryPoint{compute as .1234 = x_1234})"));
+ EXPECT_FALSE(p->BuildAndParseInternalModule());
+ EXPECT_THAT(p->error(),
+ Eq("entry point name is not a valid WGSL identifier: .1234"));
}
} // namespace
diff --git a/src/reader/spirv/parser_impl_test.cc b/src/reader/spirv/parser_impl_test.cc
index d1facdc..fa35918 100644
--- a/src/reader/spirv/parser_impl_test.cc
+++ b/src/reader/spirv/parser_impl_test.cc
@@ -205,6 +205,27 @@
EXPECT_EQ(0u, s99.column);
}
+TEST_F(SpvParserTest, Impl_IsValidIdentifier) {
+ EXPECT_FALSE(ParserImpl::IsValidIdentifier("")); // empty
+ EXPECT_FALSE(
+ ParserImpl::IsValidIdentifier("_")); // leading underscore, but ok later
+ EXPECT_FALSE(
+ ParserImpl::IsValidIdentifier("9")); // leading digit, but ok later
+ EXPECT_FALSE(ParserImpl::IsValidIdentifier(" ")); // leading space
+ EXPECT_FALSE(ParserImpl::IsValidIdentifier("a ")); // trailing space
+ EXPECT_FALSE(ParserImpl::IsValidIdentifier("a 1")); // space in the middle
+ EXPECT_FALSE(ParserImpl::IsValidIdentifier(".")); // weird character
+
+ // a simple identifier
+ EXPECT_TRUE(ParserImpl::IsValidIdentifier("A"));
+ // each upper case letter
+ EXPECT_TRUE(ParserImpl::IsValidIdentifier("ABCDEFGHIJKLMNOPQRSTUVWXYZ"));
+ // each lower case letter
+ EXPECT_TRUE(ParserImpl::IsValidIdentifier("abcdefghijklmnopqrstuvwxyz"));
+ EXPECT_TRUE(ParserImpl::IsValidIdentifier("a0123456789")); // each digit
+ EXPECT_TRUE(ParserImpl::IsValidIdentifier("x_")); // has underscore
+}
+
} // namespace
} // namespace spirv
} // namespace reader