Import Tint changes from Dawn
Changes:
- a8e7cb73d418e79d9e39a760544038ed0229b46b Add Renamer transform to ShaderModuleGL and fixes by Shrek Shao <shrekshao@google.com>
- f629f749ebd42249b88dfbf379f4bdfef6f7c53d [tint] Make lang/*/validate APIs consistent by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: a8e7cb73d418e79d9e39a760544038ed0229b46b
Change-Id: Ib89e11a536647aa669618850aa62edee3fa7453b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/158681
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/cmd/remote_compile/main.cc b/src/tint/cmd/remote_compile/main.cc
index b14f4a9..1cade2d 100644
--- a/src/tint/cmd/remote_compile/main.cc
+++ b/src/tint/cmd/remote_compile/main.cc
@@ -37,7 +37,7 @@
#include <vector>
#if TINT_BUILD_MSL_WRITER
-#include "src/tint/lang/msl/validate/val.h"
+#include "src/tint/lang/msl/validate/validate.h"
#endif
#include "src/tint/utils/macros/compiler.h"
@@ -446,7 +446,7 @@
if (req.version_major == 2 && req.version_minor == 3) {
version = tint::msl::validate::MslVersion::kMsl_2_3;
}
- auto result = tint::msl::validate::UsingMetalAPI(req.source, version);
+ auto result = tint::msl::validate::ValidateUsingMetal(req.source, version);
CompileResponse resp;
if (result.failed) {
resp.error = result.output;
diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc
index 8f6a967..0eb15fe 100644
--- a/src/tint/cmd/tint/main.cc
+++ b/src/tint/cmd/tint/main.cc
@@ -82,12 +82,12 @@
#endif // TINT_BUILD_WGSL_WRITER
#if TINT_BUILD_MSL_WRITER
-#include "src/tint/lang/msl/validate/val.h"
+#include "src/tint/lang/msl/validate/validate.h"
#include "src/tint/lang/msl/writer/writer.h"
#endif // TINT_BUILD_MSL_WRITER
#if TINT_BUILD_HLSL_WRITER
-#include "src/tint/lang/hlsl/validate/val.h"
+#include "src/tint/lang/hlsl/validate/validate.h"
#include "src/tint/lang/hlsl/writer/writer.h"
#endif // TINT_BUILD_HLSL_WRITER
@@ -748,7 +748,7 @@
if (options.validate && options.skip_hash.count(hash) == 0) {
tint::msl::validate::Result res;
#ifdef __APPLE__
- res = tint::msl::validate::UsingMetalAPI(result->msl, msl_version);
+ res = tint::msl::validate::ValidateUsingMetal(result->msl, msl_version);
#else
#ifdef _WIN32
const char* default_xcrun_exe = "metal.exe";
@@ -758,7 +758,7 @@
auto xcrun = tint::Command::LookPath(
options.xcrun_path.empty() ? default_xcrun_exe : std::string(options.xcrun_path));
if (xcrun.Found()) {
- res = tint::msl::validate::Msl(xcrun.Path(), result->msl, msl_version);
+ res = tint::msl::validate::Validate(xcrun.Path(), result->msl, msl_version);
} else {
res.output = "xcrun executable not found. Cannot validate.";
res.failed = true;
@@ -826,7 +826,7 @@
}
}
- dxc_res = tint::hlsl::validate::UsingDXC(
+ dxc_res = tint::hlsl::validate::ValidateUsingDXC(
dxc.Path(), result->hlsl, result->entry_points, dxc_require_16bit_types);
} else if (must_validate_dxc) {
// DXC was explicitly requested. Error if it could not be found.
@@ -846,8 +846,8 @@
#ifdef _WIN32
if (fxc.Found()) {
fxc_found = true;
- fxc_res =
- tint::hlsl::validate::UsingFXC(fxc.Path(), result->hlsl, result->entry_points);
+ fxc_res = tint::hlsl::validate::ValidateUsingFXC(fxc.Path(), result->hlsl,
+ result->entry_points);
} else if (must_validate_fxc) {
// FXC was explicitly requested. Error if it could not be found.
fxc_res.failed = true;
diff --git a/src/tint/lang/hlsl/validate/BUILD.bazel b/src/tint/lang/hlsl/validate/BUILD.bazel
index 640a8bc..512dfd0 100644
--- a/src/tint/lang/hlsl/validate/BUILD.bazel
+++ b/src/tint/lang/hlsl/validate/BUILD.bazel
@@ -39,10 +39,10 @@
cc_library(
name = "validate",
srcs = [
- "hlsl.cc",
+ "validate.cc",
],
hdrs = [
- "val.h",
+ "validate.h",
],
deps = [
"//src/tint/lang/wgsl/ast",
diff --git a/src/tint/lang/hlsl/validate/BUILD.cmake b/src/tint/lang/hlsl/validate/BUILD.cmake
index 93e2f03..a16018e 100644
--- a/src/tint/lang/hlsl/validate/BUILD.cmake
+++ b/src/tint/lang/hlsl/validate/BUILD.cmake
@@ -41,8 +41,8 @@
# Condition: TINT_BUILD_HLSL_WRITER
################################################################################
tint_add_target(tint_lang_hlsl_validate lib
- lang/hlsl/validate/hlsl.cc
- lang/hlsl/validate/val.h
+ lang/hlsl/validate/validate.cc
+ lang/hlsl/validate/validate.h
)
tint_target_add_dependencies(tint_lang_hlsl_validate lib
diff --git a/src/tint/lang/hlsl/validate/BUILD.gn b/src/tint/lang/hlsl/validate/BUILD.gn
index b4f94d2..44d66c5 100644
--- a/src/tint/lang/hlsl/validate/BUILD.gn
+++ b/src/tint/lang/hlsl/validate/BUILD.gn
@@ -40,8 +40,8 @@
if (tint_build_hlsl_writer) {
libtint_source_set("validate") {
sources = [
- "hlsl.cc",
- "val.h",
+ "validate.cc",
+ "validate.h",
]
deps = [
"${tint_src_dir}/lang/wgsl/ast",
diff --git a/src/tint/lang/hlsl/validate/hlsl.cc b/src/tint/lang/hlsl/validate/validate.cc
similarity index 94%
rename from src/tint/lang/hlsl/validate/hlsl.cc
rename to src/tint/lang/hlsl/validate/validate.cc
index 5a164d3..e0cd85f 100644
--- a/src/tint/lang/hlsl/validate/hlsl.cc
+++ b/src/tint/lang/hlsl/validate/validate.cc
@@ -27,7 +27,7 @@
#include <string>
-#include "src/tint/lang/hlsl/validate/val.h"
+#include "src/tint/lang/hlsl/validate/validate.h"
#include "src/tint/utils/command/command.h"
#include "src/tint/utils/file/tmpfile.h"
@@ -44,10 +44,10 @@
namespace tint::hlsl::validate {
-Result UsingDXC(const std::string& dxc_path,
- const std::string& source,
- const EntryPointList& entry_points,
- bool require_16bit_types) {
+Result ValidateUsingDXC(const std::string& dxc_path,
+ const std::string& source,
+ const EntryPointList& entry_points,
+ bool require_16bit_types) {
Result result;
auto dxc = tint::Command(dxc_path);
@@ -121,9 +121,9 @@
}
#ifdef _WIN32
-Result UsingFXC(const std::string& fxc_path,
- const std::string& source,
- const EntryPointList& entry_points) {
+Result ValidateUsingFXC(const std::string& fxc_path,
+ const std::string& source,
+ const EntryPointList& entry_points) {
Result result;
// This library leaks if an error happens in this function, but it is ok
diff --git a/src/tint/lang/hlsl/validate/val.h b/src/tint/lang/hlsl/validate/validate.h
similarity index 83%
rename from src/tint/lang/hlsl/validate/val.h
rename to src/tint/lang/hlsl/validate/validate.h
index 80c4acc..766f203 100644
--- a/src/tint/lang/hlsl/validate/val.h
+++ b/src/tint/lang/hlsl/validate/validate.h
@@ -25,8 +25,8 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#ifndef SRC_TINT_LANG_HLSL_VALIDATE_VAL_H_
-#define SRC_TINT_LANG_HLSL_VALIDATE_VAL_H_
+#ifndef SRC_TINT_LANG_HLSL_VALIDATE_VALIDATE_H_
+#define SRC_TINT_LANG_HLSL_VALIDATE_VALIDATE_H_
#include <string>
#include <utility>
@@ -60,10 +60,10 @@
/// @param source the generated HLSL source
/// @param entry_points the list of entry points to validate
/// @return the result of the compile
-Result UsingDXC(const std::string& dxc_path,
- const std::string& source,
- const EntryPointList& entry_points,
- bool require_16bit_types);
+Result ValidateUsingDXC(const std::string& dxc_path,
+ const std::string& source,
+ const EntryPointList& entry_points,
+ bool require_16bit_types);
#ifdef _WIN32
/// Hlsl attempts to compile the shader with FXC, verifying that the shader
@@ -72,11 +72,11 @@
/// @param source the generated HLSL source
/// @param entry_points the list of entry points to validate
/// @return the result of the compile
-Result UsingFXC(const std::string& fxc_path,
- const std::string& source,
- const EntryPointList& entry_points);
+Result ValidateUsingFXC(const std::string& fxc_path,
+ const std::string& source,
+ const EntryPointList& entry_points);
#endif // _WIN32
} // namespace tint::hlsl::validate
-#endif // SRC_TINT_LANG_HLSL_VALIDATE_VAL_H_
+#endif // SRC_TINT_LANG_HLSL_VALIDATE_VALIDATE_H_
diff --git a/src/tint/lang/msl/validate/BUILD.bazel b/src/tint/lang/msl/validate/BUILD.bazel
index d27091a..5ea0d1a 100644
--- a/src/tint/lang/msl/validate/BUILD.bazel
+++ b/src/tint/lang/msl/validate/BUILD.bazel
@@ -39,15 +39,15 @@
cc_library(
name = "validate",
srcs = [
- "msl.cc",
+ "validate.cc",
] + select({
":is_mac": [
- "msl_metal.mm",
+ "validate_metal.mm",
],
"//conditions:default": [],
}),
hdrs = [
- "val.h",
+ "validate.h",
],
deps = [
"//src/tint/lang/core",
diff --git a/src/tint/lang/msl/validate/BUILD.cmake b/src/tint/lang/msl/validate/BUILD.cmake
index 2ccef7d..14012fe 100644
--- a/src/tint/lang/msl/validate/BUILD.cmake
+++ b/src/tint/lang/msl/validate/BUILD.cmake
@@ -41,8 +41,8 @@
# Condition: TINT_BUILD_MSL_WRITER
################################################################################
tint_add_target(tint_lang_msl_validate lib
- lang/msl/validate/msl.cc
- lang/msl/validate/val.h
+ lang/msl/validate/validate.cc
+ lang/msl/validate/validate.h
)
tint_target_add_dependencies(tint_lang_msl_validate lib
@@ -71,7 +71,7 @@
if(IS_MAC)
tint_target_add_sources(tint_lang_msl_validate lib
- "lang/msl/validate/msl_metal.mm"
+ "lang/msl/validate/validate_metal.mm"
)
tint_target_add_external_dependencies(tint_lang_msl_validate lib
"metal"
diff --git a/src/tint/lang/msl/validate/BUILD.gn b/src/tint/lang/msl/validate/BUILD.gn
index ac8a8f4..5789793 100644
--- a/src/tint/lang/msl/validate/BUILD.gn
+++ b/src/tint/lang/msl/validate/BUILD.gn
@@ -40,8 +40,8 @@
if (tint_build_msl_writer) {
libtint_source_set("validate") {
sources = [
- "msl.cc",
- "val.h",
+ "validate.cc",
+ "validate.h",
]
deps = [
"${tint_src_dir}/lang/core",
@@ -68,7 +68,7 @@
]
if (is_mac) {
- sources += [ "msl_metal.mm" ]
+ sources += [ "validate_metal.mm" ]
deps += [ "${tint_src_dir}:metal" ]
}
}
diff --git a/src/tint/lang/msl/validate/msl.cc b/src/tint/lang/msl/validate/validate.cc
similarity index 95%
rename from src/tint/lang/msl/validate/msl.cc
rename to src/tint/lang/msl/validate/validate.cc
index 35be180..2230f54 100644
--- a/src/tint/lang/msl/validate/msl.cc
+++ b/src/tint/lang/msl/validate/validate.cc
@@ -25,7 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#include "src/tint/lang/msl/validate/val.h"
+#include "src/tint/lang/msl/validate/validate.h"
#include "src/tint/lang/wgsl/ast/module.h"
#include "src/tint/lang/wgsl/program/program.h"
@@ -34,7 +34,7 @@
namespace tint::msl::validate {
-Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version) {
+Result Validate(const std::string& xcrun_path, const std::string& source, MslVersion version) {
Result result;
auto xcrun = tint::Command(xcrun_path);
diff --git a/src/tint/lang/msl/validate/val.h b/src/tint/lang/msl/validate/validate.h
similarity index 89%
rename from src/tint/lang/msl/validate/val.h
rename to src/tint/lang/msl/validate/validate.h
index 594f9ba..89581d6 100644
--- a/src/tint/lang/msl/validate/val.h
+++ b/src/tint/lang/msl/validate/validate.h
@@ -25,8 +25,8 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#ifndef SRC_TINT_LANG_MSL_VALIDATE_VAL_H_
-#define SRC_TINT_LANG_MSL_VALIDATE_VAL_H_
+#ifndef SRC_TINT_LANG_MSL_VALIDATE_VALIDATE_H_
+#define SRC_TINT_LANG_MSL_VALIDATE_VALIDATE_H_
#include <string>
#include <utility>
@@ -67,7 +67,7 @@
/// @param source the generated MSL source
/// @param version the version of MSL to validate against
/// @return the result of the compile
-Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version);
+Result Validate(const std::string& xcrun_path, const std::string& source, MslVersion version);
#ifdef __APPLE__
/// Msl attempts to compile the shader with the runtime Metal Shader Compiler
@@ -75,9 +75,9 @@
/// @param source the generated MSL source
/// @param version the version of MSL to validate against
/// @return the result of the compile
-Result UsingMetalAPI(const std::string& source, MslVersion version);
+Result ValidateUsingMetal(const std::string& source, MslVersion version);
#endif // __APPLE__
} // namespace tint::msl::validate
-#endif // SRC_TINT_LANG_MSL_VALIDATE_VAL_H_
+#endif // SRC_TINT_LANG_MSL_VALIDATE_VALIDATE_H_
diff --git a/src/tint/lang/msl/validate/msl_metal.mm b/src/tint/lang/msl/validate/validate_metal.mm
similarity index 95%
rename from src/tint/lang/msl/validate/msl_metal.mm
rename to src/tint/lang/msl/validate/validate_metal.mm
index 2b56ea0..4669e6c 100644
--- a/src/tint/lang/msl/validate/msl_metal.mm
+++ b/src/tint/lang/msl/validate/validate_metal.mm
@@ -29,11 +29,11 @@
#import <Metal/Metal.h>
-#include "src/tint/lang/msl/validate/val.h"
+#include "src/tint/lang/msl/validate/validate.h"
namespace tint::msl::validate {
-Result UsingMetalAPI(const std::string& src, MslVersion version) {
+Result ValidateUsingMetal(const std::string& src, MslVersion version) {
Result result;
NSError* error = nil;
diff --git a/src/tint/lang/wgsl/ast/transform/renamer.cc b/src/tint/lang/wgsl/ast/transform/renamer.cc
index 89dc5b4..0183f76 100644
--- a/src/tint/lang/wgsl/ast/transform/renamer.cc
+++ b/src/tint/lang/wgsl/ast/transform/renamer.cc
@@ -1265,6 +1265,8 @@
Renamer::Data::~Data() = default;
Renamer::Config::Config(Target t, bool pu) : target(t), preserve_unicode(pu) {}
+Renamer::Config::Config(Target t, bool pu, Remappings&& remappings)
+ : target(t), preserve_unicode(pu), requested_names(std::move(remappings)) {}
Renamer::Config::Config(const Config&) = default;
Renamer::Config::~Config() = default;
@@ -1342,10 +1344,12 @@
Target target = Target::kAll;
bool preserve_unicode = false;
+ const Remappings* requested_names = nullptr;
if (auto* cfg = inputs.Get<Config>()) {
target = cfg->target;
preserve_unicode = cfg->preserve_unicode;
+ requested_names = &(cfg->requested_names);
}
// Returns true if the symbol should be renamed based on the input configuration settings.
@@ -1394,7 +1398,17 @@
}
// Create a replacement for this symbol, if we haven't already.
- auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
+ auto replacement = remappings.GetOrCreate(symbol, [&] {
+ if (requested_names) {
+ auto iter = requested_names->find(symbol.Name());
+ if (iter != requested_names->end()) {
+ // Use the explicitly given name for renaming this symbol
+ // if the extra is given in the config.
+ return b.Symbols().New(iter->second);
+ }
+ }
+ return b.Symbols().New();
+ });
// Reconstruct the identifier
if (auto* tmpl_ident = ident->As<TemplatedIdentifier>()) {
@@ -1407,7 +1421,7 @@
ctx.Clone();
- Data::Remappings out;
+ Remappings out;
for (auto it : remappings) {
out[it.key.Name()] = it.value.Name();
}
diff --git a/src/tint/lang/wgsl/ast/transform/renamer.h b/src/tint/lang/wgsl/ast/transform/renamer.h
index 0e32df2..3a390c3 100644
--- a/src/tint/lang/wgsl/ast/transform/renamer.h
+++ b/src/tint/lang/wgsl/ast/transform/renamer.h
@@ -38,12 +38,12 @@
/// Renamer is a Transform that renames all the symbols in a program.
class Renamer final : public Castable<Renamer, Transform> {
public:
+ /// Remappings is a map of old symbol name to new symbol name
+ using Remappings = std::unordered_map<std::string, std::string>;
+
/// Data is outputted by the Renamer transform.
/// Data holds information about shader usage and constant buffer offsets.
struct Data final : public Castable<Data, transform::Data> {
- /// Remappings is a map of old symbol name to new symbol name
- using Remappings = std::unordered_map<std::string, std::string>;
-
/// Constructor
/// @param remappings the symbol remappings
explicit Data(Remappings&& remappings);
@@ -79,6 +79,12 @@
/// renamed
explicit Config(Target tgt, bool keep_unicode = false);
+ /// Constructor
+ /// @param tgt the targets to rename
+ /// @param keep_unicode if false, symbols with non-ascii code-points are renamed
+ /// @param remappings requested old to new name map
+ explicit Config(Target tgt, bool keep_unicode, Remappings&& remappings);
+
/// Copy constructor
Config(const Config&);
@@ -90,6 +96,9 @@
/// If false, symbols with non-ascii code-points are renamed.
bool preserve_unicode = false;
+
+ /// Requested renaming rules
+ const Remappings requested_names = {};
};
/// Constructor using a the configuration provided in the input Data
diff --git a/src/tint/lang/wgsl/ast/transform/renamer_test.cc b/src/tint/lang/wgsl/ast/transform/renamer_test.cc
index 2969f22..d31f0a6 100644
--- a/src/tint/lang/wgsl/ast/transform/renamer_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/renamer_test.cc
@@ -95,7 +95,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"vert_idx", "tint_symbol_1"},
{"test", "tint_symbol"},
{"entry", "tint_symbol_2"},
@@ -103,6 +103,74 @@
EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
}
+TEST_F(RenamerTest, RequestedNames) {
+ auto* src = R"(
+struct ShaderIO {
+ @location(1) var1: f32,
+ @location(3) @interpolate(flat) var3: u32,
+ @builtin(position) pos: vec4f,
+}
+
+@vertex fn main(@builtin(vertex_index) vert_idx : u32)
+ -> ShaderIO {
+ var pos = array(
+ vec2f(-1.0, 3.0),
+ vec2f(-1.0, -3.0),
+ vec2f(3.0, 0.0));
+
+ var shaderIO: ShaderIO;
+ shaderIO.var1 = 0.0;
+ shaderIO.var3 = 1u;
+ shaderIO.pos = vec4f(pos[vert_idx], 0.0, 1.0);
+
+ return shaderIO;
+}
+)";
+
+ auto* expect = R"(
+struct tint_symbol {
+ @location(1)
+ user_var1 : f32,
+ @location(3) @interpolate(flat)
+ user_var3 : u32,
+ @builtin(position)
+ tint_symbol_1 : vec4f,
+}
+
+@vertex
+fn tint_symbol_2(@builtin(vertex_index) tint_symbol_3 : u32) -> tint_symbol {
+ var tint_symbol_1 = array(vec2f(-(1.0), 3.0), vec2f(-(1.0), -(3.0)), vec2f(3.0, 0.0));
+ var tint_symbol_4 : tint_symbol;
+ tint_symbol_4.user_var1 = 0.0;
+ tint_symbol_4.user_var3 = 1u;
+ tint_symbol_4.tint_symbol_1 = vec4f(tint_symbol_1[tint_symbol_3], 0.0, 1.0);
+ return tint_symbol_4;
+}
+)";
+
+ DataMap inputs;
+ inputs.Add<Renamer::Config>(Renamer::Target::kAll,
+ /* preserve_unicode */ false,
+ /* remappings */
+ Renamer::Remappings{
+ {"var1", "user_var1"},
+ {"var3", "user_var3"},
+ });
+ auto got = Run<Renamer>(src, inputs);
+
+ EXPECT_EQ(expect, str(got));
+
+ auto* data = got.data.Get<Renamer::Data>();
+
+ ASSERT_NE(data, nullptr);
+ Renamer::Remappings expected_remappings = {
+ {"pos", "tint_symbol_1"}, {"vert_idx", "tint_symbol_3"}, {"ShaderIO", "tint_symbol"},
+ {"shaderIO", "tint_symbol_4"}, {"main", "tint_symbol_2"}, {"var1", "user_var1"},
+ {"var3", "user_var3"},
+ };
+ EXPECT_THAT(data->remappings, ContainerEq(expected_remappings));
+}
+
TEST_F(RenamerTest, PreserveSwizzles) {
auto* src = R"(
@vertex
@@ -133,7 +201,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"entry", "tint_symbol"}, {"v", "tint_symbol_1"}, {"rgba", "tint_symbol_2"},
{"xyzw", "tint_symbol_3"}, {"z", "tint_symbol_4"},
};
@@ -164,7 +232,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"entry", "tint_symbol"},
{"blah", "tint_symbol_1"},
};
@@ -199,7 +267,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"entry", "tint_symbol"}, {"a", "tint_symbol_1"}, {"b", "tint_symbol_2"},
{"c", "tint_symbol_3"}, {"d", "tint_symbol_4"},
};
@@ -241,7 +309,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"entry", "tint_symbol"},
{"value", "tint_symbol_1"},
};
@@ -319,7 +387,7 @@
auto* data = got.data.Get<Renamer::Data>();
ASSERT_NE(data, nullptr);
- Renamer::Data::Remappings expected_remappings = {
+ Renamer::Remappings expected_remappings = {
{"entry", "tint_symbol"},
{"tint_symbol", "tint_symbol_1"},
{"tint_symbol_2", "tint_symbol_2"},
diff --git a/src/tint/lang/wgsl/inspector/entry_point.cc b/src/tint/lang/wgsl/inspector/entry_point.cc
index e28ce54..080f418 100644
--- a/src/tint/lang/wgsl/inspector/entry_point.cc
+++ b/src/tint/lang/wgsl/inspector/entry_point.cc
@@ -32,6 +32,7 @@
StageVariable::StageVariable() = default;
StageVariable::StageVariable(const StageVariable& other)
: name(other.name),
+ variable_name(other.variable_name),
has_location_attribute(other.has_location_attribute),
location_attribute(other.location_attribute),
component_type(other.component_type),
diff --git a/src/tint/lang/wgsl/inspector/entry_point.h b/src/tint/lang/wgsl/inspector/entry_point.h
index e8f19e2..1dee5c1 100644
--- a/src/tint/lang/wgsl/inspector/entry_point.h
+++ b/src/tint/lang/wgsl/inspector/entry_point.h
@@ -82,8 +82,10 @@
/// Destructor
~StageVariable();
- /// Name of the variable in the shader.
+ /// Name of the variable in the shader. (including struct nested accessing, e.g. 'struct.var')
std::string name;
+ /// Name of the variable itself. (e.g. 'var')
+ std::string variable_name;
/// Is location attribute present
bool has_location_attribute = false;
/// Value of the location attribute, only valid if #has_location_attribute is
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index aa9bcfa..278773a 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -175,7 +175,8 @@
}
for (auto* param : sem->Parameters()) {
- AddEntryPointInOutVariables(param->Declaration()->name->symbol.Name(), param->Type(),
+ AddEntryPointInOutVariables(param->Declaration()->name->symbol.Name(),
+ param->Declaration()->name->symbol.Name(), param->Type(),
param->Declaration()->attributes, param->Location(),
entry_point.input_variables);
@@ -196,7 +197,7 @@
}
if (!sem->ReturnType()->Is<core::type::Void>()) {
- AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
+ AddEntryPointInOutVariables("<retval>", "", sem->ReturnType(), func->return_type_attributes,
sem->ReturnLocation(), entry_point.output_variables);
entry_point.output_sample_mask_used = ContainsBuiltin(
@@ -575,6 +576,7 @@
}
void Inspector::AddEntryPointInOutVariables(std::string name,
+ std::string variable_name,
const core::type::Type* type,
VectorRef<const ast::Attribute*> attributes,
std::optional<uint32_t> location,
@@ -589,8 +591,8 @@
if (auto* struct_ty = unwrapped_type->As<sem::Struct>()) {
// Recurse into members.
for (auto* member : struct_ty->Members()) {
- AddEntryPointInOutVariables(name + "." + member->Name().Name(), member->Type(),
- member->Declaration()->attributes,
+ AddEntryPointInOutVariables(name + "." + member->Name().Name(), member->Name().Name(),
+ member->Type(), member->Declaration()->attributes,
member->Attributes().location, variables);
}
return;
@@ -600,6 +602,7 @@
StageVariable stage_variable;
stage_variable.name = name;
+ stage_variable.variable_name = variable_name;
std::tie(stage_variable.component_type, stage_variable.composition_type) =
CalculateComponentAndComposition(type);
diff --git a/src/tint/lang/wgsl/inspector/inspector.h b/src/tint/lang/wgsl/inspector/inspector.h
index 9b12ba6..1af117d 100644
--- a/src/tint/lang/wgsl/inspector/inspector.h
+++ b/src/tint/lang/wgsl/inspector/inspector.h
@@ -170,12 +170,14 @@
/// Recursively add entry point IO variables.
/// If `type` is a struct, recurse into members, appending the member name.
/// Otherwise, add the variable unless it is a builtin.
- /// @param name the name of the variable being added
+ /// @param name the name of the variable being added, including struct nested accessings.
+ /// @param variable_name the name of the variable being added
/// @param type the type of the variable
/// @param attributes the variable attributes
/// @param location the location value if provided
/// @param variables the list to add the variables to
void AddEntryPointInOutVariables(std::string name,
+ std::string variable_name,
const core::type::Type* type,
VectorRef<const ast::Attribute*> attributes,
std::optional<uint32_t> location,
diff --git a/src/tint/lang/wgsl/inspector/inspector_test.cc b/src/tint/lang/wgsl/inspector/inspector_test.cc
index 009fbdc..80435da 100644
--- a/src/tint/lang/wgsl/inspector/inspector_test.cc
+++ b/src/tint/lang/wgsl/inspector/inspector_test.cc
@@ -442,12 +442,14 @@
ASSERT_EQ(1u, result[0].input_variables.size());
EXPECT_EQ("in_var", result[0].input_variables[0].name);
+ EXPECT_EQ("in_var", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(component, result[0].input_variables[0].component_type);
ASSERT_EQ(1u, result[0].output_variables.size());
EXPECT_EQ("<retval>", result[0].output_variables[0].name);
+ EXPECT_EQ("", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(component, result[0].output_variables[0].component_type);
@@ -498,16 +500,19 @@
ASSERT_EQ(3u, result[0].input_variables.size());
EXPECT_EQ("in_var0", result[0].input_variables[0].name);
+ EXPECT_EQ("in_var0", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[0].interpolation_type);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("in_var1", result[0].input_variables[1].name);
+ EXPECT_EQ("in_var1", result[0].input_variables[1].variable_name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[1].interpolation_type);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
EXPECT_EQ("in_var4", result[0].input_variables[2].name);
+ EXPECT_EQ("in_var4", result[0].input_variables[2].variable_name);
EXPECT_TRUE(result[0].input_variables[2].has_location_attribute);
EXPECT_EQ(4u, result[0].input_variables[2].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[2].interpolation_type);
@@ -515,6 +520,7 @@
ASSERT_EQ(1u, result[0].output_variables.size());
EXPECT_EQ("<retval>", result[0].output_variables[0].name);
+ EXPECT_EQ("", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
@@ -562,6 +568,7 @@
ASSERT_EQ(1u, result[0].input_variables.size());
EXPECT_EQ("in_var_foo", result[0].input_variables[0].name);
+ EXPECT_EQ("in_var_foo", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[0].input_variables[0].interpolation_type);
@@ -569,12 +576,14 @@
ASSERT_EQ(1u, result[0].output_variables.size());
EXPECT_EQ("<retval>", result[0].output_variables[0].name);
+ EXPECT_EQ("", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
ASSERT_EQ(1u, result[1].input_variables.size());
EXPECT_EQ("in_var_bar", result[1].input_variables[0].name);
+ EXPECT_EQ("in_var_bar", result[1].input_variables[0].variable_name);
EXPECT_TRUE(result[1].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[1].input_variables[0].location_attribute);
EXPECT_EQ(InterpolationType::kFlat, result[1].input_variables[0].interpolation_type);
@@ -582,6 +591,7 @@
ASSERT_EQ(1u, result[1].output_variables.size());
EXPECT_EQ("<retval>", result[1].output_variables[0].name);
+ EXPECT_EQ("", result[1].output_variables[0].variable_name);
EXPECT_TRUE(result[1].output_variables[0].has_location_attribute);
EXPECT_EQ(1u, result[1].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[1].output_variables[0].component_type);
@@ -615,6 +625,7 @@
ASSERT_EQ(1u, result[0].input_variables.size());
EXPECT_EQ("in_var1", result[0].input_variables[0].name);
+ EXPECT_EQ("in_var1", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kF32, result[0].input_variables[0].component_type);
@@ -647,20 +658,24 @@
ASSERT_EQ(2u, result[0].input_variables.size());
EXPECT_EQ("param.a", result[0].input_variables[0].name);
+ EXPECT_EQ("a", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("param.b", result[0].input_variables[1].name);
+ EXPECT_EQ("b", result[0].input_variables[1].variable_name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
ASSERT_EQ(2u, result[0].output_variables.size());
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
+ EXPECT_EQ("a", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
+ EXPECT_EQ("b", result[0].output_variables[1].variable_name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);
@@ -693,20 +708,24 @@
ASSERT_EQ(2u, result[0].output_variables.size());
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
+ EXPECT_EQ("a", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
+ EXPECT_EQ("b", result[0].output_variables[1].variable_name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);
ASSERT_EQ(2u, result[1].input_variables.size());
EXPECT_EQ("param.a", result[1].input_variables[0].name);
+ EXPECT_EQ("a", result[1].input_variables[0].variable_name);
EXPECT_TRUE(result[1].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[1].input_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[1].input_variables[0].component_type);
EXPECT_EQ("param.b", result[1].input_variables[1].name);
+ EXPECT_EQ("b", result[1].input_variables[1].variable_name);
EXPECT_TRUE(result[1].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[1].input_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[1].input_variables[1].component_type);
@@ -745,32 +764,39 @@
ASSERT_EQ(5u, result[0].input_variables.size());
EXPECT_EQ("param_a.a", result[0].input_variables[0].name);
+ EXPECT_EQ("a", result[0].input_variables[0].variable_name);
EXPECT_TRUE(result[0].input_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].input_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[0].component_type);
EXPECT_EQ("param_a.b", result[0].input_variables[1].name);
+ EXPECT_EQ("b", result[0].input_variables[1].variable_name);
EXPECT_TRUE(result[0].input_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].input_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[1].component_type);
EXPECT_EQ("param_b.a", result[0].input_variables[2].name);
+ EXPECT_EQ("a", result[0].input_variables[2].variable_name);
EXPECT_TRUE(result[0].input_variables[2].has_location_attribute);
EXPECT_EQ(2u, result[0].input_variables[2].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].input_variables[2].component_type);
EXPECT_EQ("param_c", result[0].input_variables[3].name);
+ EXPECT_EQ("param_c", result[0].input_variables[3].variable_name);
EXPECT_TRUE(result[0].input_variables[3].has_location_attribute);
EXPECT_EQ(3u, result[0].input_variables[3].location_attribute);
EXPECT_EQ(ComponentType::kF32, result[0].input_variables[3].component_type);
EXPECT_EQ("param_d", result[0].input_variables[4].name);
+ EXPECT_EQ("param_d", result[0].input_variables[4].variable_name);
EXPECT_TRUE(result[0].input_variables[4].has_location_attribute);
EXPECT_EQ(4u, result[0].input_variables[4].location_attribute);
EXPECT_EQ(ComponentType::kF32, result[0].input_variables[4].component_type);
ASSERT_EQ(2u, result[0].output_variables.size());
EXPECT_EQ("<retval>.a", result[0].output_variables[0].name);
+ EXPECT_EQ("a", result[0].output_variables[0].variable_name);
EXPECT_TRUE(result[0].output_variables[0].has_location_attribute);
EXPECT_EQ(0u, result[0].output_variables[0].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[0].component_type);
EXPECT_EQ("<retval>.b", result[0].output_variables[1].name);
+ EXPECT_EQ("b", result[0].output_variables[1].variable_name);
EXPECT_TRUE(result[0].output_variables[1].has_location_attribute);
EXPECT_EQ(1u, result[0].output_variables[1].location_attribute);
EXPECT_EQ(ComponentType::kU32, result[0].output_variables[1].component_type);