Import Tint changes from Dawn
Changes:
- 1145db63263f759dcf9ce7fec522155a57974b42 [tint][ir][spriv-writer] Add tests for switch phis by Ben Clayton <bclayton@google.com>
- dd1a46a3bd0fbcda2f470cde4abc1704995b46d5 [tint][cmd] Fix TINT_BUILD_SYNTAX_TREE_WRITER build by Ben Clayton <bclayton@google.com>
- 377b4ae0b9b311da647b0a83ec63baf21ae39b51 [tint][ir] Rename Loop::Start to Loop::Body by Ben Clayton <bclayton@google.com>
- 7cb5fc8c2da91ae5e0257678bbdb5bd3563ec5bd [tint][ir] Add args parameter to all branch instructions by Ben Clayton <bclayton@google.com>
- 576f199a4e8d5c264687d01ecfdd97dba1025927 [tint][ir][spirv-writer] Emit phi nodes by Ben Clayton <bclayton@google.com>
- 781f5bb11563ee77f217f45620553f3ea8fec552 [tint][resolver]: Don't use recursion for constant conver... by Ben Clayton <bclayton@google.com>
- 3e2119d7a2887954433c8f6706ae4f818baae01e [ir][spirv-writer] Emit array types by James Price <jrprice@google.com>
- 3ee81bbacfb2833c2335dcb1ada8a9af3d3ef3a9 [tint][utils]: Add utilities for command line flag parsing by Ben Clayton <bclayton@google.com>
- dced7533133b33644994e507a7a4c7cb74a677f1 [ir][spirv-writer] Emit workgroup variables by James Price <jrprice@google.com>
- 5e2130a87e3aa4fc2063ba5bc9c069c96f6de12d [ir][spirv-writer] Emit private variables by James Price <jrprice@google.com>
- ad5a0e20560cebbdf9d54b3fc6208093358f0256 d3d: handle vertex_index and instance_index separately by Peng Huang <penghuang@chromium.org>
GitOrigin-RevId: 1145db63263f759dcf9ce7fec522155a57974b42
Change-Id: I3a65e734997571940f77fd0917a93fefe96bd374
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/135480
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 83fcbf5..ce779bc 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -237,6 +237,8 @@
"utils/bump_allocator.h",
"utils/castable.cc",
"utils/castable.h",
+ "utils/cli.cc",
+ "utils/cli.h",
"utils/compiler_macros.h",
"utils/concat.h",
"utils/crc32.h",
@@ -1822,6 +1824,7 @@
"utils/block_allocator_test.cc",
"utils/bump_allocator_test.cc",
"utils/castable_test.cc",
+ "utils/cli_test.cc",
"utils/crc32_test.cc",
"utils/defer_test.cc",
"utils/enum_set_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index c20c850..2ecb97a 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -525,6 +525,8 @@
utils/bump_allocator.h
utils/castable.cc
utils/castable.h
+ utils/cli.cc
+ utils/cli.h
utils/compiler_macros.h
utils/concat.h
utils/crc32.h
@@ -1048,6 +1050,7 @@
utils/bitset_test.cc
utils/block_allocator_test.cc
utils/bump_allocator_test.cc
+ utils/cli_test.cc
utils/castable_test.cc
utils/crc32_test.cc
utils/defer_test.cc
diff --git a/src/tint/ast/transform/first_index_offset.cc b/src/tint/ast/transform/first_index_offset.cc
index 4c77efb..dd28585 100644
--- a/src/tint/ast/transform/first_index_offset.cc
+++ b/src/tint/ast/transform/first_index_offset.cc
@@ -51,8 +51,8 @@
FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g) : binding(b), group(g) {}
FirstIndexOffset::BindingPoint::~BindingPoint() = default;
-FirstIndexOffset::Data::Data(bool has_vtx_or_inst_index)
- : has_vertex_or_instance_index(has_vtx_or_inst_index) {}
+FirstIndexOffset::Data::Data(bool has_vtx_index, bool has_inst_index)
+ : has_vertex_index(has_vtx_index), has_instance_index(has_inst_index) {}
FirstIndexOffset::Data::Data(const Data&) = default;
FirstIndexOffset::Data::~Data() = default;
@@ -81,7 +81,8 @@
std::unordered_map<const sem::Variable*, const char*> builtin_vars;
std::unordered_map<const type::StructMember*, const char*> builtin_members;
- bool has_vertex_or_instance_index = false;
+ bool has_vertex_index = false;
+ bool has_instance_index = false;
// Traverse the AST scanning for builtin accesses via variables (includes
// parameters) or structure member accesses.
@@ -93,12 +94,12 @@
if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_var = ctx.src->Sem().Get(var);
builtin_vars.emplace(sem_var, kFirstVertexName);
- has_vertex_or_instance_index = true;
+ has_vertex_index = true;
}
if (builtin == builtin::BuiltinValue::kInstanceIndex) {
auto* sem_var = ctx.src->Sem().Get(var);
builtin_vars.emplace(sem_var, kFirstInstanceName);
- has_vertex_or_instance_index = true;
+ has_instance_index = true;
}
}
}
@@ -110,19 +111,19 @@
if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_mem = ctx.src->Sem().Get(member);
builtin_members.emplace(sem_mem, kFirstVertexName);
- has_vertex_or_instance_index = true;
+ has_vertex_index = true;
}
if (builtin == builtin::BuiltinValue::kInstanceIndex) {
auto* sem_mem = ctx.src->Sem().Get(member);
builtin_members.emplace(sem_mem, kFirstInstanceName);
- has_vertex_or_instance_index = true;
+ has_instance_index = true;
}
}
}
}
}
- if (has_vertex_or_instance_index) {
+ if (has_vertex_index || has_instance_index) {
// Add uniform buffer members and calculate byte offsets
utils::Vector<const StructMember*, 8> members;
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
@@ -160,7 +161,7 @@
});
}
- outputs.Add<Data>(has_vertex_or_instance_index);
+ outputs.Add<Data>(has_vertex_index, has_instance_index);
ctx.Clone();
return Program(std::move(b));
diff --git a/src/tint/ast/transform/first_index_offset.h b/src/tint/ast/transform/first_index_offset.h
index 44a3a63..c1dd512 100644
--- a/src/tint/ast/transform/first_index_offset.h
+++ b/src/tint/ast/transform/first_index_offset.h
@@ -84,9 +84,9 @@
/// Data holds information about shader usage and constant buffer offsets.
struct Data final : public utils::Castable<Data, Transform::Data> {
/// Constructor
- /// @param has_vtx_or_inst_index True if the shader uses vertex_index or
- /// instance_index
- explicit Data(bool has_vtx_or_inst_index);
+ /// @param has_vtx_index True if the shader uses vertex_index
+ /// @param has_inst_index True if the shader uses instance_index
+ Data(bool has_vtx_index, bool has_inst_index);
/// Copy constructor
Data(const Data&);
@@ -95,7 +95,9 @@
~Data() override;
/// True if the shader uses vertex_index
- const bool has_vertex_or_instance_index;
+ const bool has_vertex_index;
+ /// True if the shader uses instance_index
+ const bool has_instance_index;
};
/// Constructor
diff --git a/src/tint/ast/transform/first_index_offset_test.cc b/src/tint/ast/transform/first_index_offset_test.cc
index 8dc441f..ce480e2 100644
--- a/src/tint/ast/transform/first_index_offset_test.cc
+++ b/src/tint/ast/transform/first_index_offset_test.cc
@@ -86,7 +86,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, false);
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
@@ -130,7 +131,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
@@ -174,7 +176,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
@@ -218,7 +221,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
@@ -262,7 +266,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, false);
+ EXPECT_EQ(data->has_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
@@ -318,7 +323,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
@@ -374,7 +380,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, NestedCalls) {
@@ -426,7 +433,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
@@ -478,7 +486,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, false);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints) {
@@ -546,7 +555,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
@@ -614,7 +624,8 @@
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
- EXPECT_EQ(data->has_vertex_or_instance_index, true);
+ EXPECT_EQ(data->has_vertex_index, true);
+ EXPECT_EQ(data->has_instance_index, true);
}
} // namespace
diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc
index 5d82479..ada1603 100644
--- a/src/tint/cmd/main.cc
+++ b/src/tint/cmd/main.cc
@@ -41,6 +41,8 @@
#include "src/tint/ast/module.h"
#include "src/tint/cmd/generate_external_texture_bindings.h"
#include "src/tint/cmd/helper.h"
+#include "src/tint/utils/cli.h"
+#include "src/tint/utils/defer.h"
#include "src/tint/utils/io/command.h"
#include "src/tint/utils/string.h"
#include "src/tint/utils/string_stream.h"
@@ -54,6 +56,36 @@
#include "src/tint/ir/module.h" // nogncheck
#endif // TINT_BUILD_IR
+#if TINT_BUILD_SPV_WRITER
+#define SPV_WRITER_ONLY(x) x
+#else
+#define SPV_WRITER_ONLY(x)
+#endif
+
+#if TINT_BUILD_WGSL_WRITER
+#define WGSL_WRITER_ONLY(x) x
+#else
+#define WGSL_WRITER_ONLY(x)
+#endif
+
+#if TINT_BUILD_MSL_WRITER
+#define MSL_WRITER_ONLY(x) x
+#else
+#define MSL_WRITER_ONLY(x)
+#endif
+
+#if TINT_BUILD_HLSL_WRITER
+#define HLSL_WRITER_ONLY(x) x
+#else
+#define HLSL_WRITER_ONLY(x)
+#endif
+
+#if TINT_BUILD_GLSL_WRITER
+#define GLSL_WRITER_ONLY(x) x
+#else
+#define GLSL_WRITER_ONLY(x)
+#endif
+
namespace {
/// Prints the given hash value in a format string that the end-to-end test runner can parse.
@@ -73,7 +105,6 @@
};
struct Options {
- bool show_help = false;
bool verbose = false;
std::string input_filename;
@@ -99,12 +130,12 @@
tint::reader::spirv::Options spirv_reader_options;
#endif
- std::vector<std::string> transforms;
+ tint::utils::Vector<std::string, 4> transforms;
std::string fxc_path;
std::string dxc_path;
std::string xcrun_path;
- std::unordered_map<std::string, double> overrides;
+ tint::utils::Hashmap<std::string, double, 8> overrides;
std::optional<tint::sem::BindingPoint> hlsl_root_constant_binding_point;
#if TINT_BUILD_IR
@@ -113,96 +144,10 @@
#endif // TINT_BUILD_IR
#if TINT_BUILD_SYNTAX_TREE_WRITER
- bool dump_syntax_tree = false;
-#endif // TINB_BUILD_SYNTAX_TREE_WRITER
+ bool dump_ast = false;
+#endif // TINT_BUILD_SYNTAX_TREE_WRITER
};
-const char kUsage[] = R"(Usage: tint [options] <input-file>
-
- options:
- --format <spirv|spvasm|wgsl|msl|hlsl|none> -- Output format.
- If not provided, will be inferred from output
- filename extension:
- .spvasm -> spvasm
- .spv -> spirv
- .wgsl -> wgsl
- .metal -> msl
- .hlsl -> hlsl
- If none matches, then default to SPIR-V assembly.
- -ep <name> -- Output single entry point
- --output-file <name> -- Output file name. Use "-" for standard output
- -o <name> -- Output file name. Use "-" for standard output
- --transform <name list> -- Runs transforms, name list is comma separated
- Available transforms:
-${transforms} --parse-only -- Stop after parsing the input
- --allow-non-uniform-derivatives -- When using SPIR-V input, allow non-uniform derivatives by
- inserting a module-scope directive to suppress any uniformity
- violations that may be produced.
- --disable-workgroup-init -- Disable workgroup memory zero initialization.
- --dump-inspector-bindings -- Dump reflection data about bindins to stdout.
- -h -- This help text
- --hlsl-root-constant-binding-point <group>,<binding> -- Binding point for root constant.
- Specify the binding point for generated uniform buffer
- used for num_workgroups in HLSL. If not specified, then
- default to binding 0 of the largest used group plus 1,
- or group 0 if no resource bound.
- --validate -- Validates the generated shader with all available validators
- --skip-hash <hash list> -- Skips validation if the hash of the output is equal to any
- of the hash codes in the comma separated list of hashes
- --print-hash -- Emit the hash of the output program
- --fxc -- Path to FXC dll, used to validate HLSL output.
- When specified, automatically enables HLSL validation with FXC
- --dxc -- Path to DXC executable, used to validate HLSL output.
- When specified, automatically enables HLSL validation with DXC
- --xcrun -- Path to xcrun executable, used to validate MSL output.
- When specified, automatically enables MSL validation
- --overrides -- Override values as IDENTIFIER=VALUE, comma-separated.
- --rename-all -- Renames all symbols.
-)";
-
-Format parse_format(const std::string& fmt) {
- (void)fmt;
-
-#if TINT_BUILD_SPV_WRITER
- if (fmt == "spirv") {
- return Format::kSpirv;
- }
- if (fmt == "spvasm") {
- return Format::kSpvAsm;
- }
-#endif // TINT_BUILD_SPV_WRITER
-
-#if TINT_BUILD_WGSL_WRITER
- if (fmt == "wgsl") {
- return Format::kWgsl;
- }
-#endif // TINT_BUILD_WGSL_WRITER
-
-#if TINT_BUILD_MSL_WRITER
- if (fmt == "msl") {
- return Format::kMsl;
- }
-#endif // TINT_BUILD_MSL_WRITER
-
-#if TINT_BUILD_HLSL_WRITER
- if (fmt == "hlsl") {
- return Format::kHlsl;
- }
-#endif // TINT_BUILD_HLSL_WRITER
-
-#if TINT_BUILD_GLSL_WRITER
- if (fmt == "glsl") {
- return Format::kGlsl;
- }
-#endif // TINT_BUILD_GLSL_WRITER
-
- if (fmt == "none") {
- return Format::kNone;
- }
-
- return Format::kUnknown;
-}
-
/// @param filename the filename to inspect
/// @returns the inferred format for the filename suffix
Format infer_format(const std::string& filename) {
@@ -238,112 +183,154 @@
return Format::kUnknown;
}
-std::vector<std::string> split_on_char(std::string list, char c) {
- std::vector<std::string> res;
+bool ParseArgs(tint::utils::VectorRef<std::string_view> arguments,
+ std::string transform_names,
+ Options* opts) {
+ using namespace tint::utils::cli; // NOLINT(build/namespaces)
- std::istringstream str(list);
- while (str.good()) {
- std::string substr;
- getline(str, substr, c);
- res.push_back(substr);
- }
- return res;
-}
+ tint::utils::Vector<EnumName<Format>, 8> format_enum_names{
+ EnumName(Format::kNone, "none"),
+ };
-std::vector<std::string> split_on_comma(std::string list) {
- return split_on_char(list, ',');
-}
+ SPV_WRITER_ONLY(format_enum_names.Emplace(Format::kSpirv, "spirv"));
+ SPV_WRITER_ONLY(format_enum_names.Emplace(Format::kSpvAsm, "spvasm"));
+ WGSL_WRITER_ONLY(format_enum_names.Emplace(Format::kWgsl, "wgsl"));
+ MSL_WRITER_ONLY(format_enum_names.Emplace(Format::kMsl, "msl"));
+ HLSL_WRITER_ONLY(format_enum_names.Emplace(Format::kHlsl, "hlsl"));
+ GLSL_WRITER_ONLY(format_enum_names.Emplace(Format::kGlsl, "glsl"));
-std::vector<std::string> split_on_equal(std::string list) {
- return split_on_char(list, '=');
-}
+ OptionSet options;
+ auto& fmt = options.Add<EnumOption<Format>>("format",
+ R"(Output format.
+If not provided, will be inferred from output filename extension:
+ .spvasm -> spvasm
+ .spv -> spirv
+ .wgsl -> wgsl
+ .metal -> msl
+ .hlsl -> hlsl)",
+ format_enum_names, ShortName{"f"});
+ TINT_DEFER(opts->format = fmt.value.value_or(Format::kUnknown));
-std::optional<uint64_t> parse_unsigned_number(std::string number) {
- for (char c : number) {
- if (!std::isdigit(c)) {
- // Found a non-digital char, return nullopt
- return std::nullopt;
- }
- }
-
- errno = 0;
- char* p_end;
- uint64_t result;
- // std::strtoull will not throw exception.
- result = std::strtoull(number.c_str(), &p_end, 10);
- if ((errno != 0) || (static_cast<size_t>(p_end - number.c_str()) != number.length())) {
- // Unexpected conversion result
- return std::nullopt;
- }
-
- return result;
-}
-
-bool ParseArgs(const std::vector<std::string>& args, Options* opts) {
- for (size_t i = 1; i < args.size(); ++i) {
- const std::string& arg = args[i];
- if (arg == "--format") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for --format argument." << std::endl;
- return false;
- }
- opts->format = parse_format(args[i]);
-
- if (opts->format == Format::kUnknown) {
- std::cerr << "Unknown output format: " << args[i] << std::endl;
- return false;
- }
- } else if (arg == "-ep") {
- if (i + 1 >= args.size()) {
- std::cerr << "Missing value for -ep" << std::endl;
- return false;
- }
- i++;
- opts->ep_name = args[i];
+ auto& ep = options.Add<StringOption>("entry-point", "Output single entry point",
+ ShortName{"ep"}, Parameter{"name"});
+ TINT_DEFER({
+ if (ep.value.has_value()) {
+ opts->ep_name = *ep.value;
opts->emit_single_entry_point = true;
+ }
+ });
- } else if (arg == "-o" || arg == "--output-name") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
- return false;
- }
- opts->output_file = args[i];
+ auto& output = options.Add<StringOption>("output-name", "Output file name", ShortName{"o"},
+ Parameter{"name"});
+ TINT_DEFER(opts->output_file = output.value.value_or(""));
- } else if (arg == "-h" || arg == "--help") {
- opts->show_help = true;
- } else if (arg == "-v" || arg == "--verbose") {
- opts->verbose = true;
- } else if (arg == "--transform") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
- return false;
- }
- opts->transforms = split_on_comma(args[i]);
- } else if (arg == "--parse-only") {
- opts->parse_only = true;
- } else if (arg == "--allow-non-uniform-derivatives") {
-#if TINT_BUILD_SPV_READER
- opts->spirv_reader_options.allow_non_uniform_derivatives = true;
-#else
- std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl;
- return false;
-#endif
- } else if (arg == "--disable-workgroup-init") {
- opts->disable_workgroup_init = true;
- } else if (arg == "--dump-inspector-bindings") {
- opts->dump_inspector_bindings = true;
- } else if (arg == "--validate") {
+ auto& fxc_path =
+ options.Add<StringOption>("fxc", R"(Path to FXC dll, used to validate HLSL output.
+When specified, automatically enables HLSL validation with FXC)",
+ Parameter{"path"});
+ TINT_DEFER(opts->fxc_path = fxc_path.value.value_or(""));
+
+ auto& dxc_path =
+ options.Add<StringOption>("dxc", R"(Path to DXC dll, used to validate HLSL output.
+When specified, automatically enables HLSL validation with DXC)",
+ Parameter{"path"});
+ TINT_DEFER(opts->dxc_path = dxc_path.value.value_or(""));
+
+ auto& xcrun =
+ options.Add<StringOption>("xcrun", R"(Path to xcrun executable, used to validate MSL output.
+When specified, automatically enables MSL validation)",
+ Parameter{"path"});
+ TINT_DEFER({
+ if (xcrun.value.has_value()) {
+ opts->xcrun_path = *xcrun.value;
opts->validate = true;
- } else if (arg == "--skip-hash") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing hash value for " << arg << std::endl;
- return false;
+ }
+ });
+
+#if TINT_BUILD_IR
+ auto& dump_ir = options.Add<BoolOption>("dump-ir", "Writes the IR to stdout", Alias{"emit-ir"},
+ Default{false});
+ TINT_DEFER(opts->dump_ir = *dump_ir.value);
+
+ auto& use_ir = options.Add<BoolOption>(
+ "use-ir", "Use the IR for writers and transforms when possible", Default{false});
+ TINT_DEFER(opts->use_ir = *use_ir.value);
+#endif // TINT_BUILD_IR
+
+ auto& verbose =
+ options.Add<BoolOption>("verbose", "Verbose output", ShortName{"v"}, Default{false});
+ TINT_DEFER(opts->verbose = *verbose.value);
+
+ auto& validate = options.Add<BoolOption>(
+ "validate", "Validates the generated shader with all available validators", Default{false});
+ TINT_DEFER(opts->validate = *validate.value);
+
+ auto& parse_only =
+ options.Add<BoolOption>("parse-only", "Stop after parsing the input", Default{false});
+ TINT_DEFER(opts->parse_only = *parse_only.value);
+
+#if TINT_BUILD_SPV_READER
+ auto& allow_nud =
+ options.Add<BoolOption>("allow-non-uniform-derivatives",
+ R"(When using SPIR-V input, allow non-uniform derivatives by
+inserting a module-scope directive to suppress any uniformity
+violations that may be produced)",
+ Default{false});
+ TINT_DEFER({
+ if (allow_nud.value.value_or(false)) {
+ opts->spirv_reader_options.allow_non_uniform_derivatives = true;
+ }
+ });
+#endif
+
+ auto& disable_wg_init = options.Add<BoolOption>(
+ "disable-workgroup-init", "Disable workgroup memory zero initialization", Default{false});
+ TINT_DEFER(opts->disable_workgroup_init = *disable_wg_init.value);
+
+ auto& rename_all = options.Add<BoolOption>("rename-all", "Renames all symbols", Default{false});
+ TINT_DEFER(opts->rename_all = *rename_all.value);
+
+ auto& dump_inspector_bindings = options.Add<BoolOption>(
+ "dump-inspector-bindings", "Dump reflection data about bindings to stdout",
+ Alias{"emit-inspector-bindings"}, Default{false});
+ TINT_DEFER(opts->dump_inspector_bindings = *dump_inspector_bindings.value);
+
+#if TINT_BUILD_SYNTAX_TREE_WRITER
+ auto& dump_ast = options.Add<BoolOption>("dump-ast", "Writes the AST to stdout",
+ Alias{"emit-ast"}, Default{false});
+ TINT_DEFER(opts->dump_ast = *dump_ast.value);
+#endif // TINT_BUILD_SYNTAX_TREE_WRITER
+
+ auto& print_hash = options.Add<BoolOption>("print-hash", "Emit the hash of the output program",
+ Default{false});
+ TINT_DEFER(opts->print_hash = *print_hash.value);
+
+ auto& transforms =
+ options.Add<StringOption>("transform", R"(Runs transforms, name list is comma separated
+Available transforms:
+)" + transform_names,
+ ShortName{"t"});
+ TINT_DEFER({
+ if (transforms.value.has_value()) {
+ for (auto transform : tint::utils::Split(*transforms.value, ",")) {
+ opts->transforms.Push(std::string(transform));
}
- for (auto hash : split_on_comma(args[i])) {
+ }
+ });
+
+ auto& hlsl_rc_bp = options.Add<StringOption>("hlsl-root-constant-binding-point",
+ R"(Binding point for root constant.
+Specify the binding point for generated uniform buffer
+used for num_workgroups in HLSL. If not specified, then
+default to binding 0 of the largest used group plus 1,
+or group 0 if no resource bound)");
+
+ auto& skip_hash = options.Add<StringOption>(
+ "skip-hash", R"(Skips validation if the hash of the output is equal to any
+of the hash codes in the comma separated list of hashes)");
+ TINT_DEFER({
+ if (skip_hash.value.has_value()) {
+ for (auto hash : tint::utils::Split(*skip_hash.value, ",")) {
uint32_t value = 0;
int base = 10;
if (hash.size() > 2 && hash[0] == '0' && (hash[1] == 'x' || hash[1] == 'X')) {
@@ -353,91 +340,82 @@
std::from_chars(hash.data(), hash.data() + hash.size(), value, base);
opts->skip_hash.emplace(value);
}
- } else if (arg == "--print-hash") {
- opts->print_hash = true;
- } else if (arg == "--fxc") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
+ }
+ });
+
+ auto& overrides = options.Add<StringOption>(
+ "overrides", "Override values as IDENTIFIER=VALUE, comma-separated");
+
+ auto& help = options.Add<BoolOption>("help", "Show usage", ShortName{"h"});
+
+ auto show_usage = [&] {
+ std::cout << R"(Usage: tint [options] <input-file>
+
+Options:
+)";
+ options.ShowHelp(std::cout);
+ };
+
+ auto result = options.Parse(std::cerr, arguments);
+ if (!result) {
+ std::cerr << std::endl;
+ show_usage();
+ return false;
+ }
+ if (help.value.value_or(false)) {
+ show_usage();
+ return false;
+ }
+
+ if (overrides.value.has_value()) {
+ for (const auto& o : tint::utils::Split(*overrides.value, ",")) {
+ auto parts = tint::utils::Split(o, "=");
+ if (parts.Length() != 2) {
+ std::cerr << "override values must be of the form IDENTIFIER=VALUE";
return false;
}
- opts->fxc_path = args[i];
- } else if (arg == "--dxc") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
+ auto value = tint::utils::ParseNumber<double>(parts[1]);
+ if (!value) {
+ std::cerr << "invalid override value: " << parts[1];
return false;
}
- opts->dxc_path = args[i];
-#if TINT_BUILD_IR
- } else if (arg == "--dump-ir") {
- opts->dump_ir = true;
- } else if (arg == "--use-ir") {
- opts->use_ir = true;
-#endif // TINT_BUILD_IR
-#if TINT_BUILD_SYNTAX_TREE_WRITER
- } else if (arg == "--dump-ast") {
- opts->dump_syntax_tree = true;
-#endif // TINT_BUILD_SYNTAX_TREE_WRITER
- } else if (arg == "--xcrun") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
- return false;
- }
- opts->xcrun_path = args[i];
- opts->validate = true;
- } else if (arg == "--overrides") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
- return false;
- }
- for (const auto& o : split_on_comma(args[i])) {
- auto parts = split_on_equal(o);
- opts->overrides.insert({parts[0], std::stod(parts[1])});
- }
- } else if (arg == "--rename-all") {
- ++i;
- opts->rename_all = true;
- } else if (arg == "--hlsl-root-constant-binding-point") {
- ++i;
- if (i >= args.size()) {
- std::cerr << "Missing value for " << arg << std::endl;
- return false;
- }
- auto binding_points = split_on_comma(args[i]);
- if (binding_points.size() != 2) {
- std::cerr << "Invalid binding point for " << arg << ": " << args[i] << std::endl;
- return false;
- }
- auto group = parse_unsigned_number(binding_points[0]);
- if ((!group.has_value()) || (group.value() > std::numeric_limits<uint32_t>::max())) {
- std::cerr << "Invalid group for " << arg << ": " << binding_points[0] << std::endl;
- return false;
- }
- auto binding = parse_unsigned_number(binding_points[1]);
- if ((!binding.has_value()) ||
- (binding.value() > std::numeric_limits<uint32_t>::max())) {
- std::cerr << "Invalid binding for " << arg << ": " << binding_points[1]
- << std::endl;
- return false;
- }
- opts->hlsl_root_constant_binding_point = tint::sem::BindingPoint{
- static_cast<uint32_t>(group.value()), static_cast<uint32_t>(binding.value())};
- } else if (!arg.empty()) {
- if (arg[0] == '-') {
- std::cerr << "Unrecognized option: " << arg << std::endl;
- return false;
- }
- if (!opts->input_filename.empty()) {
- std::cerr << "More than one input file specified: '" << opts->input_filename
- << "' and '" << arg << "'" << std::endl;
- return false;
- }
- opts->input_filename = arg;
+ opts->overrides.Add(std::string(parts[0]), value.Get());
}
}
+
+ if (hlsl_rc_bp.value.has_value()) {
+ auto binding_points = tint::utils::Split(*hlsl_rc_bp.value, ",");
+ if (binding_points.Length() != 2) {
+ std::cerr << "Invalid binding point for " << hlsl_rc_bp.name << ": "
+ << *hlsl_rc_bp.value << std::endl;
+ return false;
+ }
+ auto group = tint::utils::ParseUint32(binding_points[0]);
+ if (!group) {
+ std::cerr << "Invalid group for " << hlsl_rc_bp.name << ": " << binding_points[0]
+ << std::endl;
+ return false;
+ }
+ auto binding = tint::utils::ParseUint32(binding_points[1]);
+ if (!binding) {
+ std::cerr << "Invalid binding for " << hlsl_rc_bp.name << ": " << binding_points[1]
+ << std::endl;
+ return false;
+ }
+ opts->hlsl_root_constant_binding_point =
+ tint::sem::BindingPoint{group.Get(), binding.Get()};
+ }
+
+ auto files = result.Get();
+ if (files.Length() > 1) {
+ std::cerr << "More than one input file specified: "
+ << tint::utils::Join(Transform(files, tint::utils::Quote), ", ") << std::endl;
+ return false;
+ }
+ if (files.Length() == 1) {
+ opts->input_filename = files[0];
+ }
+
return true;
}
@@ -677,7 +655,7 @@
const char* default_xcrun_exe = "xcrun";
#endif
auto xcrun = tint::utils::Command::LookPath(
- options.xcrun_path.empty() ? default_xcrun_exe : options.xcrun_path);
+ options.xcrun_path.empty() ? default_xcrun_exe : std::string(options.xcrun_path));
if (xcrun.Found()) {
res = tint::val::Msl(xcrun.Path(), result.msl);
} else {
@@ -738,8 +716,8 @@
tint::val::Result dxc_res;
bool dxc_found = false;
if (options.validate || must_validate_dxc) {
- auto dxc =
- tint::utils::Command::LookPath(options.dxc_path.empty() ? "dxc" : options.dxc_path);
+ auto dxc = tint::utils::Command::LookPath(
+ options.dxc_path.empty() ? "dxc" : std::string(options.dxc_path));
if (dxc.Found()) {
dxc_found = true;
@@ -757,8 +735,8 @@
} else if (must_validate_dxc) {
// DXC was explicitly requested. Error if it could not be found.
dxc_res.failed = true;
- dxc_res.output =
- "DXC executable '" + options.dxc_path + "' not found. Cannot validate";
+ dxc_res.output = "DXC executable '" + std::string(options.dxc_path) +
+ "' not found. Cannot validate";
}
}
@@ -766,7 +744,7 @@
bool fxc_found = false;
if (options.validate || must_validate_fxc) {
auto fxc = tint::utils::Command::LookPath(
- options.fxc_path.empty() ? tint::val::kFxcDLLName : options.fxc_path);
+ options.fxc_path.empty() ? tint::val::kFxcDLLName : std::string(options.fxc_path));
#ifdef _WIN32
if (fxc.Found()) {
@@ -913,7 +891,14 @@
} // namespace
int main(int argc, const char** argv) {
- std::vector<std::string> args(argv, argv + argc);
+ tint::utils::Vector<std::string_view, 8> arguments;
+ for (int i = 1; i < argc; i++) {
+ std::string_view arg(argv[i]);
+ if (!arg.empty()) {
+ arguments.Push(argv[i]);
+ }
+ }
+
Options options;
tint::SetInternalCompilerErrorReporter(&tint::cmd::TintInternalCompilerErrorReporter);
@@ -928,11 +913,6 @@
};
#endif // TINT_BUILD_WGSL_WRITER
- if (!ParseArgs(args, &options)) {
- std::cerr << "Failed to parse arguments." << std::endl;
- return 1;
- }
-
struct TransformFactory {
const char* name;
/// Build and adds the transform to the transform manager.
@@ -970,22 +950,23 @@
tint::ast::transform::SubstituteOverride::Config cfg;
std::unordered_map<tint::OverrideId, double> values;
- values.reserve(options.overrides.size());
+ values.reserve(options.overrides.Count());
- for (const auto& [name, value] : options.overrides) {
+ for (auto override : options.overrides) {
+ const auto& name = override.key;
+ const auto& value = override.value;
if (name.empty()) {
- std::cerr << "empty override name";
+ std::cerr << "empty override name" << std::endl;
return false;
}
- if (isdigit(name[0])) {
- tint::OverrideId id{
- static_cast<decltype(tint::OverrideId::value)>(atoi(name.c_str()))};
+ if (auto num = tint::utils::ParseNumber<decltype(tint::OverrideId::value)>(name)) {
+ tint::OverrideId id{num.Get()};
values.emplace(id, value);
} else {
auto override_names = inspector.GetNamedOverrideIds();
auto it = override_names.find(name);
if (it == override_names.end()) {
- std::cerr << "unknown override '" << name << "'";
+ std::cerr << "unknown override '" << name << "'" << std::endl;
return false;
}
values.emplace(it->second, value);
@@ -1007,20 +988,8 @@
return names.str();
};
- if (options.show_help) {
- std::string usage = tint::utils::ReplaceAll(kUsage, "${transforms}", transform_names());
-#if TINT_BUILD_IR
- usage +=
- " --dump-ir -- Writes the IR to stdout\n"
- " --dump-ir-graph -- Writes the IR graph to 'tint.dot' as a dot graph\n"
- " --use-ir -- Use the IR for writers and transforms when possible\n";
-#endif // TINT_BUILD_IR
-#if TINT_BUILD_SYNTAX_TREE_WRITER
- usage += " --dump-ast -- Writes the AST to stdout\n";
-#endif // TINT_BUILD_SYNTAX_TREE_WRITER
-
- std::cout << usage << std::endl;
- return 0;
+ if (!ParseArgs(arguments, transform_names(), &options)) {
+ return 1;
}
// Implement output format defaults.
@@ -1056,7 +1025,7 @@
}
#if TINT_BUILD_SYNTAX_TREE_WRITER
- if (options.dump_syntax_tree) {
+ if (options.dump_ast) {
tint::writer::syntax_tree::Options gen_options;
auto result = tint::writer::syntax_tree::Generate(program.get(), gen_options);
if (!result.success) {
@@ -1138,12 +1107,12 @@
}
std::cerr << "Unknown transform: " << name << std::endl;
- std::cerr << "Available transforms: " << std::endl << transform_names();
+ std::cerr << "Available transforms: " << std::endl << transform_names() << std::endl;
return false;
};
// If overrides are provided, add the SubstituteOverride transform.
- if (!options.overrides.empty()) {
+ if (!options.overrides.IsEmpty()) {
if (!enable_transform("substitute_override")) {
return 1;
}
diff --git a/src/tint/constant/scalar.cc b/src/tint/constant/scalar.cc
index b77cb3b..9a9e551 100644
--- a/src/tint/constant/scalar.cc
+++ b/src/tint/constant/scalar.cc
@@ -14,6 +14,7 @@
#include "src/tint/constant/scalar.h"
+TINT_INSTANTIATE_TYPEINFO(tint::constant::ScalarBase);
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<tint::AInt>);
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<tint::AFloat>);
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<tint::i32>);
@@ -21,3 +22,9 @@
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<tint::f16>);
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<tint::f32>);
TINT_INSTANTIATE_TYPEINFO(tint::constant::Scalar<bool>);
+
+namespace tint::constant {
+
+ScalarBase::~ScalarBase() = default;
+
+}
diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h
index 474c9ad..23684ec 100644
--- a/src/tint/constant/scalar.h
+++ b/src/tint/constant/scalar.h
@@ -24,9 +24,16 @@
namespace tint::constant {
+/// ScalarBase is the base class of all Scalar<T> specializations.
+/// Used for querying whether a value is a scalar type.
+class ScalarBase : public utils::Castable<ScalarBase, Value> {
+ public:
+ ~ScalarBase() override;
+};
+
/// Scalar holds a single scalar or abstract-numeric value.
template <typename T>
-class Scalar : public utils::Castable<Scalar<T>, Value> {
+class Scalar : public utils::Castable<Scalar<T>, ScalarBase> {
public:
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool");
diff --git a/src/tint/ir/break_if.cc b/src/tint/ir/break_if.cc
index f19fb79..c17ef54 100644
--- a/src/tint/ir/break_if.cc
+++ b/src/tint/ir/break_if.cc
@@ -14,19 +14,23 @@
#include "src/tint/ir/break_if.h"
+#include <utility>
+
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::BreakIf);
namespace tint::ir {
-BreakIf::BreakIf(Value* condition, ir::Loop* loop)
- : Base(utils::Empty), condition_(condition), loop_(loop) {
+BreakIf::BreakIf(Value* condition,
+ ir::Loop* loop,
+ utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), condition_(condition), loop_(loop) {
TINT_ASSERT(IR, condition_);
TINT_ASSERT(IR, loop_);
condition_->AddUsage(this);
loop_->AddUsage(this);
- loop_->Start()->AddInboundBranch(this);
+ loop_->Body()->AddInboundBranch(this);
loop_->Merge()->AddInboundBranch(this);
}
diff --git a/src/tint/ir/break_if.h b/src/tint/ir/break_if.h
index 47fd4e8..9769378 100644
--- a/src/tint/ir/break_if.h
+++ b/src/tint/ir/break_if.h
@@ -32,7 +32,8 @@
/// Constructor
/// @param condition the break condition
/// @param loop the loop containing the break-if
- BreakIf(Value* condition, ir::Loop* loop);
+ /// @param args the branch arguments
+ BreakIf(Value* condition, ir::Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
~BreakIf() override;
/// @returns the break condition
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 7f2d9c3..4706d59 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -54,8 +54,8 @@
return ir.values.Create<If>(condition, CreateBlock(), CreateBlock(), CreateBlock());
}
-Loop* Builder::CreateLoop() {
- return ir.values.Create<Loop>(CreateBlock(), CreateBlock(), CreateBlock());
+Loop* Builder::CreateLoop(utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<Loop>(CreateBlock(), CreateBlock(), CreateBlock(), std::move(args));
}
Switch* Builder::CreateSwitch(Value* condition) {
@@ -201,31 +201,36 @@
return ir.values.Create<ir::Var>(type);
}
-ir::Return* Builder::Return(Function* func, utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::Return>(func, args);
+ir::Return* Builder::Return(Function* func, utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::Return>(func, std::move(args));
}
-ir::NextIteration* Builder::NextIteration(Loop* loop) {
- return ir.values.Create<ir::NextIteration>(loop);
+ir::NextIteration* Builder::NextIteration(Loop* loop,
+ utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::NextIteration>(loop, std::move(args));
}
-ir::BreakIf* Builder::BreakIf(Value* condition, Loop* loop) {
- return ir.values.Create<ir::BreakIf>(condition, loop);
+ir::BreakIf* Builder::BreakIf(Value* condition,
+ Loop* loop,
+ utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::BreakIf>(condition, loop, std::move(args));
}
-ir::Continue* Builder::Continue(Loop* loop) {
- return ir.values.Create<ir::Continue>(loop);
-}
-ir::ExitSwitch* Builder::ExitSwitch(Switch* sw) {
- return ir.values.Create<ir::ExitSwitch>(sw);
+ir::Continue* Builder::Continue(Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::Continue>(loop, std::move(args));
}
-ir::ExitLoop* Builder::ExitLoop(Loop* loop) {
- return ir.values.Create<ir::ExitLoop>(loop);
+ir::ExitSwitch* Builder::ExitSwitch(Switch* sw,
+ utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::ExitSwitch>(sw, std::move(args));
}
-ir::ExitIf* Builder::ExitIf(If* i, utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::ExitIf>(i, args);
+ir::ExitLoop* Builder::ExitLoop(Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::ExitLoop>(loop, std::move(args));
+}
+
+ir::ExitIf* Builder::ExitIf(If* i, utils::VectorRef<Value*> args /* = utils::Empty */) {
+ return ir.values.Create<ir::ExitIf>(i, std::move(args));
}
ir::BlockParam* Builder::BlockParam(const type::Type* type) {
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index b544cb2..092162c 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -84,8 +84,9 @@
If* CreateIf(Value* condition);
/// Creates a loop flow node
+ /// @param args the branch arguments
/// @returns the flow node
- Loop* CreateLoop();
+ Loop* CreateLoop(utils::VectorRef<Value*> args = utils::Empty);
/// Creates a switch flow node
/// @param condition the switch condition
@@ -336,39 +337,46 @@
/// @param func the function being returned
/// @param args the return arguments
/// @returns the instruction
- ir::Return* Return(Function* func, utils::VectorRef<Value*> args = {});
+ ir::Return* Return(Function* func, utils::VectorRef<Value*> args = utils::Empty);
/// Creates a loop next iteration instruction
/// @param loop the loop being iterated
+ /// @param args the branch arguments
/// @returns the instruction
- ir::NextIteration* NextIteration(Loop* loop);
+ ir::NextIteration* NextIteration(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
/// Creates a loop break-if instruction
/// @param condition the break condition
/// @param loop the loop being iterated
+ /// @param args the branch arguments
/// @returns the instruction
- ir::BreakIf* BreakIf(Value* condition, Loop* loop);
+ ir::BreakIf* BreakIf(Value* condition,
+ Loop* loop,
+ utils::VectorRef<Value*> args = utils::Empty);
/// Creates a continue instruction
/// @param loop the loop being continued
+ /// @param args the branch arguments
/// @returns the instruction
- ir::Continue* Continue(Loop* loop);
+ ir::Continue* Continue(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
/// Creates an exit switch instruction
/// @param sw the switch being exited
+ /// @param args the branch arguments
/// @returns the instruction
- ir::ExitSwitch* ExitSwitch(Switch* sw);
+ ir::ExitSwitch* ExitSwitch(Switch* sw, utils::VectorRef<Value*> args = utils::Empty);
/// Creates an exit loop instruction
/// @param loop the loop being exited
+ /// @param args the branch arguments
/// @returns the instruction
- ir::ExitLoop* ExitLoop(Loop* loop);
+ ir::ExitLoop* ExitLoop(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
/// Creates an exit if instruction
/// @param i the if being exited
/// @param args the branch arguments
/// @returns the instruction
- ir::ExitIf* ExitIf(If* i, utils::VectorRef<Value*> args = {});
+ ir::ExitIf* ExitIf(If* i, utils::VectorRef<Value*> args = utils::Empty);
/// Creates a new `BlockParam`
/// @param type the parameter type
diff --git a/src/tint/ir/continue.cc b/src/tint/ir/continue.cc
index a4511ae..7edebf8 100644
--- a/src/tint/ir/continue.cc
+++ b/src/tint/ir/continue.cc
@@ -14,13 +14,16 @@
#include "src/tint/ir/continue.h"
+#include <utility>
+
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Continue);
namespace tint::ir {
-Continue::Continue(ir::Loop* loop) : Base(utils::Empty), loop_(loop) {
+Continue::Continue(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), loop_(loop) {
TINT_ASSERT(IR, loop_);
loop_->AddUsage(this);
loop_->Continuing()->AddInboundBranch(this);
diff --git a/src/tint/ir/continue.h b/src/tint/ir/continue.h
index aea601a..a954e74 100644
--- a/src/tint/ir/continue.h
+++ b/src/tint/ir/continue.h
@@ -30,7 +30,8 @@
public:
/// Constructor
/// @param loop the loop owning the continue block
- explicit Continue(ir::Loop* loop);
+ /// @param args the branch arguments
+ explicit Continue(ir::Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
~Continue() override;
/// @returns the loop owning the continue block
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index d8e472a..c7155d5 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -122,12 +122,7 @@
Indent() << "%b" << IdOf(blk) << " = block";
if (!blk->Params().IsEmpty()) {
out_ << " (";
- for (auto* p : blk->Params()) {
- if (p != blk->Params().Front()) {
- out_ << ", ";
- }
- EmitValue(p);
- }
+ EmitValueList(blk->Params());
out_ << ")";
}
@@ -435,7 +430,7 @@
}
void Disassembler::EmitLoop(const Loop* l) {
- out_ << "loop [s: %b" << IdOf(l->Start());
+ out_ << "loop [s: %b" << IdOf(l->Body());
if (l->Continuing()->HasBranchTarget()) {
out_ << ", c: %b" << IdOf(l->Continuing());
@@ -443,11 +438,18 @@
if (l->Merge()->HasBranchTarget()) {
out_ << ", m: %b" << IdOf(l->Merge());
}
- out_ << "]" << std::endl;
+ out_ << "]";
+
+ if (!l->Args().IsEmpty()) {
+ out_ << " ";
+ EmitValueList(l->Args());
+ }
+
+ out_ << std::endl;
{
ScopedIndent si(indent_size_);
- Walk(l->Start());
+ Walk(l->Body());
out_ << std::endl;
}
@@ -515,38 +517,35 @@
[&](const ir::ExitSwitch* es) { out_ << "exit_switch %b" << IdOf(es->Switch()->Merge()); },
[&](const ir::ExitLoop* el) { out_ << "exit_loop %b" << IdOf(el->Loop()->Merge()); },
[&](const ir::NextIteration* ni) {
- out_ << "next_iteration %b" << IdOf(ni->Loop()->Start());
+ out_ << "next_iteration %b" << IdOf(ni->Loop()->Body());
},
[&](const ir::BreakIf* bi) {
out_ << "break_if ";
EmitValue(bi->Condition());
- out_ << " %b" << IdOf(bi->Loop()->Start());
+ out_ << " %b" << IdOf(bi->Loop()->Body());
},
[&](Default) { out_ << "Unknown branch " << b->TypeInfo().name; });
if (!b->Args().IsEmpty()) {
out_ << " ";
- for (auto* v : b->Args()) {
- if (v != b->Args().Front()) {
- out_ << ", ";
- }
- EmitValue(v);
- }
+ EmitValueList(b->Args());
}
out_ << std::endl;
}
-void Disassembler::EmitArgs(const Call* call) {
- bool first = true;
- for (const auto* arg : call->Args()) {
- if (!first) {
+void Disassembler::EmitValueList(tint::utils::VectorRef<const tint::ir::Value*> values) {
+ for (auto* v : values) {
+ if (v != values.Front()) {
out_ << ", ";
}
- first = false;
- EmitValue(arg);
+ EmitValue(v);
}
}
+void Disassembler::EmitArgs(const Call* call) {
+ EmitValueList(call->Args());
+}
+
void Disassembler::EmitBinary(const Binary* b) {
EmitValueWithType(b);
out_ << " = ";
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 9194942..4d5d928 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -66,6 +66,7 @@
void EmitInstruction(const Instruction* inst);
void EmitValueWithType(const Value* val);
void EmitValue(const Value* val);
+ void EmitValueList(tint::utils::VectorRef<const tint::ir::Value*> values);
void EmitArgs(const Call* call);
void EmitBinary(const Binary* b);
void EmitUnary(const Unary* b);
diff --git a/src/tint/ir/exit_if.cc b/src/tint/ir/exit_if.cc
index 8b7de7f..16fac7f 100644
--- a/src/tint/ir/exit_if.cc
+++ b/src/tint/ir/exit_if.cc
@@ -14,13 +14,16 @@
#include "src/tint/ir/exit_if.h"
+#include <utility>
+
#include "src/tint/ir/if.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitIf);
namespace tint::ir {
-ExitIf::ExitIf(ir::If* i, utils::VectorRef<Value*> args) : Base(args), if_(i) {
+ExitIf::ExitIf(ir::If* i, utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), if_(i) {
TINT_ASSERT(IR, if_);
if_->AddUsage(this);
if_->Merge()->AddInboundBranch(this);
diff --git a/src/tint/ir/exit_if.h b/src/tint/ir/exit_if.h
index 9ba1421..153b714 100644
--- a/src/tint/ir/exit_if.h
+++ b/src/tint/ir/exit_if.h
@@ -31,7 +31,7 @@
/// Constructor
/// @param i the if being exited
/// @param args the branch arguments
- explicit ExitIf(ir::If* i, utils::VectorRef<Value*> args = {});
+ explicit ExitIf(ir::If* i, utils::VectorRef<Value*> args = utils::Empty);
~ExitIf() override;
/// @returns the if being exited
diff --git a/src/tint/ir/exit_loop.cc b/src/tint/ir/exit_loop.cc
index 5fe3910..0effcfc 100644
--- a/src/tint/ir/exit_loop.cc
+++ b/src/tint/ir/exit_loop.cc
@@ -14,13 +14,16 @@
#include "src/tint/ir/exit_loop.h"
+#include <utility>
+
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitLoop);
namespace tint::ir {
-ExitLoop::ExitLoop(ir::Loop* loop) : Base(utils::Empty), loop_(loop) {
+ExitLoop::ExitLoop(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), loop_(loop) {
TINT_ASSERT(IR, loop_);
loop_->AddUsage(this);
loop_->Merge()->AddInboundBranch(this);
diff --git a/src/tint/ir/exit_loop.h b/src/tint/ir/exit_loop.h
index 1df1119..4ef8110 100644
--- a/src/tint/ir/exit_loop.h
+++ b/src/tint/ir/exit_loop.h
@@ -30,7 +30,8 @@
public:
/// Constructor
/// @param loop the loop being exited
- explicit ExitLoop(ir::Loop* loop);
+ /// @param args the branch arguments
+ explicit ExitLoop(ir::Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
~ExitLoop() override;
/// @returns the loop being exited
diff --git a/src/tint/ir/exit_switch.cc b/src/tint/ir/exit_switch.cc
index ba6a178..e9679e5 100644
--- a/src/tint/ir/exit_switch.cc
+++ b/src/tint/ir/exit_switch.cc
@@ -14,13 +14,16 @@
#include "src/tint/ir/exit_switch.h"
+#include <utility>
+
#include "src/tint/ir/switch.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitSwitch);
namespace tint::ir {
-ExitSwitch::ExitSwitch(ir::Switch* sw) : Base(utils::Empty), switch_(sw) {
+ExitSwitch::ExitSwitch(ir::Switch* sw, utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), switch_(sw) {
TINT_ASSERT(IR, switch_);
switch_->AddUsage(this);
switch_->Merge()->AddInboundBranch(this);
diff --git a/src/tint/ir/exit_switch.h b/src/tint/ir/exit_switch.h
index 6b406fe..31a23e2 100644
--- a/src/tint/ir/exit_switch.h
+++ b/src/tint/ir/exit_switch.h
@@ -30,7 +30,8 @@
public:
/// Constructor
/// @param sw the switch being exited
- explicit ExitSwitch(ir::Switch* sw);
+ /// @param args the branch arguments
+ explicit ExitSwitch(ir::Switch* sw, utils::VectorRef<Value*> args = utils::Empty);
~ExitSwitch() override;
/// @returns the switch being exited
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 3dc52e1..4a442ee 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -642,7 +642,7 @@
{
ControlStackScope scope(this, loop_inst);
- current_block_ = loop_inst->Start();
+ current_block_ = loop_inst->Body();
// The loop doesn't use EmitBlock because it needs the scope stack to not get popped
// until after the continuing block.
@@ -691,7 +691,7 @@
{
ControlStackScope scope(this, loop_inst);
- current_block_ = loop_inst->Start();
+ current_block_ = loop_inst->Body();
// Emit the while condition into the Start().target of the loop
auto reg = EmitExpression(stmt->condition);
@@ -737,7 +737,7 @@
{
ControlStackScope scope(this, loop_inst);
- current_block_ = loop_inst->Start();
+ current_block_ = loop_inst->Body();
if (stmt->condition) {
// Emit the condition into the target target of the loop
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 8f5348b..404b6c3 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -342,7 +342,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(0u, flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(0u, flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
@@ -379,7 +379,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, loop_flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -437,7 +437,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, loop_flow->Merge()->InboundBranches().Length());
@@ -511,7 +511,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -563,7 +563,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(0u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
@@ -601,7 +601,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(0u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
@@ -632,7 +632,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(0u, loop_flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop_flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(2u, loop_flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -799,13 +799,13 @@
auto m = res.Move();
auto* flow = FindSingleValue<ir::Loop>(m);
- ASSERT_NE(flow->Start()->Branch(), nullptr);
- ASSERT_TRUE(flow->Start()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Start()->Branch()->As<ir::If>();
+ ASSERT_NE(flow->Body()->Branch(), nullptr);
+ ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
+ auto* if_flow = flow->Body()->Branch()->As<ir::If>();
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -860,13 +860,13 @@
auto m = res.Move();
auto* flow = FindSingleValue<ir::Loop>(m);
- ASSERT_NE(flow->Start()->Branch(), nullptr);
- ASSERT_TRUE(flow->Start()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Start()->Branch()->As<ir::If>();
+ ASSERT_NE(flow->Body()->Branch(), nullptr);
+ ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
+ auto* if_flow = flow->Body()->Branch()->As<ir::If>();
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -934,13 +934,13 @@
auto m = res.Move();
auto* flow = FindSingleValue<ir::Loop>(m);
- ASSERT_NE(flow->Start()->Branch(), nullptr);
- ASSERT_TRUE(flow->Start()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Start()->Branch()->As<ir::If>();
+ ASSERT_NE(flow->Body()->Branch(), nullptr);
+ ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
+ auto* if_flow = flow->Body()->Branch()->As<ir::If>();
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(2u, flow->Body()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(2u, flow->Merge()->InboundBranches().Length());
EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
@@ -962,7 +962,7 @@
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(0u, flow->Start()->InboundBranches().Length());
+ EXPECT_EQ(0u, flow->Body()->InboundBranches().Length());
EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
diff --git a/src/tint/ir/loop.cc b/src/tint/ir/loop.cc
index bd697bf..5b3e603 100644
--- a/src/tint/ir/loop.cc
+++ b/src/tint/ir/loop.cc
@@ -14,13 +14,18 @@
#include "src/tint/ir/loop.h"
+#include <utility>
+
TINT_INSTANTIATE_TYPEINFO(tint::ir::Loop);
namespace tint::ir {
-Loop::Loop(ir::Block* s, ir::Block* c, ir::Block* m)
- : Base(utils::Empty), start_(s), continuing_(c), merge_(m) {
- TINT_ASSERT(IR, start_);
+Loop::Loop(ir::Block* b,
+ ir::Block* c,
+ ir::Block* m,
+ utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), body_(b), continuing_(c), merge_(m) {
+ TINT_ASSERT(IR, body_);
TINT_ASSERT(IR, continuing_);
TINT_ASSERT(IR, merge_);
}
diff --git a/src/tint/ir/loop.h b/src/tint/ir/loop.h
index aadfd30..0ac9463 100644
--- a/src/tint/ir/loop.h
+++ b/src/tint/ir/loop.h
@@ -24,16 +24,17 @@
class Loop : public utils::Castable<Loop, Branch> {
public:
/// Constructor
- /// @param s the start block
+ /// @param b the body block
/// @param c the continuing block
/// @param m the merge block
- Loop(ir::Block* s, ir::Block* c, ir::Block* m);
+ /// @param args the branch arguments
+ Loop(ir::Block* b, ir::Block* c, ir::Block* m, utils::VectorRef<Value*> args = utils::Empty);
~Loop() override;
/// @returns the switch start branch
- const ir::Block* Start() const { return start_; }
+ const ir::Block* Body() const { return body_; }
/// @returns the switch start branch
- ir::Block* Start() { return start_; }
+ ir::Block* Body() { return body_; }
/// @returns the switch continuing branch
const ir::Block* Continuing() const { return continuing_; }
@@ -46,7 +47,7 @@
ir::Block* Merge() { return merge_; }
private:
- ir::Block* start_ = nullptr;
+ ir::Block* body_ = nullptr;
ir::Block* continuing_ = nullptr;
ir::Block* merge_ = nullptr;
};
diff --git a/src/tint/ir/next_iteration.cc b/src/tint/ir/next_iteration.cc
index 0c021eb..1e057d0 100644
--- a/src/tint/ir/next_iteration.cc
+++ b/src/tint/ir/next_iteration.cc
@@ -14,16 +14,19 @@
#include "src/tint/ir/next_iteration.h"
+#include <utility>
+
#include "src/tint/ir/loop.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::NextIteration);
namespace tint::ir {
-NextIteration::NextIteration(ir::Loop* loop) : Base(utils::Empty), loop_(loop) {
+NextIteration::NextIteration(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
+ : Base(std::move(args)), loop_(loop) {
TINT_ASSERT(IR, loop_);
loop_->AddUsage(this);
- loop_->Start()->AddInboundBranch(this);
+ loop_->Body()->AddInboundBranch(this);
}
NextIteration::~NextIteration() = default;
diff --git a/src/tint/ir/next_iteration.h b/src/tint/ir/next_iteration.h
index f1211e3..064ad3f 100644
--- a/src/tint/ir/next_iteration.h
+++ b/src/tint/ir/next_iteration.h
@@ -30,7 +30,8 @@
public:
/// Constructor
/// @param loop the loop being iterated
- explicit NextIteration(ir::Loop* loop);
+ /// @param args the branch arguments
+ explicit NextIteration(ir::Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
~NextIteration() override;
/// @returns the loop being iterated
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index c181f02..caf3e37 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -242,18 +242,25 @@
return count;
}
+/// Common data for constant conversion.
+struct ConvertContext {
+ ProgramBuilder& builder;
+ const Source& source;
+ bool use_runtime_semantics;
+};
+
+/// Converts the constant scalar value to the target type.
+/// @returns the converted scalar, or nullptr on error.
template <typename T>
-ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics) {
+const constant::ScalarBase* ScalarConvert(const constant::Scalar<T>* scalar,
+ const type::Type* target_ty,
+ ConvertContext& ctx) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == scalar->type) {
// If the types are identical, then no conversion is needed.
return scalar;
}
- return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ConstEval::Result {
+ return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> const constant::ScalarBase* {
// `value` is the source value.
// `FROM` is the source type.
// `TO` is the target type.
@@ -261,183 +268,219 @@
using FROM = T;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
- return builder.constants.Get<constant::Scalar<TO>>(target_ty,
- !scalar->IsPositiveZero());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ !scalar->IsPositiveZero());
} else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x]
- return builder.constants.Get<constant::Scalar<TO>>(target_ty,
- TO(scalar->value ? 1 : 0));
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO(scalar->value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(scalar->value)) {
// Conversion success
- return builder.constants.Get<constant::Scalar<TO>>(target_ty, conv.Get());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure
auto msg = OverflowErrorMessage(scalar->value, target_ty->FriendlyName());
- if (use_runtime_semantics) {
- builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
+ if (ctx.use_runtime_semantics) {
+ ctx.builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg,
+ ctx.source);
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty,
- TO::Highest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Highest());
}
} else {
- builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
- return utils::Failure;
+ ctx.builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, ctx.source);
+ return nullptr;
}
} else if constexpr (IsFloatingPoint<TO>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
auto msg = OverflowErrorMessage(scalar->value, target_ty->FriendlyName());
- if (use_runtime_semantics) {
- builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
+ if (ctx.use_runtime_semantics) {
+ ctx.builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg,
+ ctx.source);
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty,
- TO::Highest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Highest());
}
} else {
- builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
- return utils::Failure;
+ ctx.builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, ctx.source);
+ return nullptr;
}
} else if constexpr (IsFloatingPoint<FROM>) {
// [floating-point -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
- return builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Highest());
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ TO::Highest());
}
} else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable
// Static cast
- return builder.constants.Get<constant::Scalar<TO>>(target_ty,
- static_cast<TO>(scalar->value));
+ return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
+ static_cast<TO>(scalar->value));
}
- return nullptr; // Expression is not constant.
+ TINT_UNREACHABLE(Resolver, ctx.builder.Diagnostics()) << "Expression is not constant";
+ return nullptr;
});
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
-// Forward declare
-ConstEval::Result ConvertInternal(const constant::Value* c,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics);
+/// Converts the constant value to the target type.
+/// @returns the converted value, or nullptr on error.
+const constant::Value* ConvertInternal(const constant::Value* root_value,
+ const type::Type* root_target_ty,
+ ConvertContext& ctx) {
+ struct ActionConvert {
+ const constant::Value* value = nullptr;
+ const type::Type* target_ty = nullptr;
+ };
+ struct ActionBuildSplat {
+ size_t count = 0;
+ const type::Type* type = nullptr;
+ };
+ struct ActionBuildComposite {
+ size_t count = 0;
+ const type::Type* type = nullptr;
+ };
+ using Action = std::variant<ActionConvert, ActionBuildSplat, ActionBuildComposite>;
-ConstEval::Result CompositeConvert(const constant::Value* value,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics) {
- const size_t el_count = value->NumElements();
+ utils::Vector<Action, 8> pending{
+ ActionConvert{root_value, root_target_ty},
+ };
- // Convert each of the composite element types.
- utils::Vector<const constant::Value*, 4> conv_els;
- conv_els.Reserve(el_count);
+ utils::Vector<const constant::Value*, 32> value_stack;
- std::function<const type::Type*(size_t idx)> target_el_ty;
- if (auto* str = target_ty->As<type::Struct>()) {
- if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
- TINT_ICE(Resolver, builder.Diagnostics())
- << "const-eval conversion of structure has mismatched element counts";
- return utils::Failure;
+ while (!pending.IsEmpty()) {
+ auto next = pending.Pop();
+
+ if (auto* build = std::get_if<ActionBuildSplat>(&next)) {
+ TINT_ASSERT(Resolver, value_stack.Length() >= 1);
+ auto* el = value_stack.Pop();
+ value_stack.Push(ctx.builder.constants.Splat(build->type, el, build->count));
+ continue;
}
- target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
- } else {
- auto* el_ty = type::Type::ElementOf(target_ty);
- target_el_ty = [el_ty](size_t) { return el_ty; };
- }
- for (size_t i = 0; i < el_count; i++) {
- auto* el = value->Index(i);
- auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
- use_runtime_semantics);
- if (!conv_el) {
- return utils::Failure;
+ if (auto* build = std::get_if<ActionBuildComposite>(&next)) {
+ TINT_ASSERT(Resolver, value_stack.Length() >= build->count);
+ // Take build->count elements off the top of value_stack
+ // Note: The values are ordered with the first composite value at the top of the stack.
+ utils::Vector<const constant::Value*, 32> elements;
+ elements.Reserve(build->count);
+ for (size_t i = 0; i < build->count; i++) {
+ elements.Push(value_stack.Pop());
+ }
+ // Build the composite
+ value_stack.Push(ctx.builder.constants.Composite(build->type, std::move(elements)));
+ continue;
}
- if (!conv_el.Get()) {
+
+ auto* convert = std::get_if<ActionConvert>(&next);
+
+ bool ok = Switch(
+ convert->value,
+ [&](const constant::ScalarBase* scalar) {
+ auto* converted = Switch(
+ scalar,
+ [&](const constant::Scalar<tint::AFloat>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<tint::AInt>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<tint::u32>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<tint::i32>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<tint::f32>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<tint::f16>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ },
+ [&](const constant::Scalar<bool>* val) {
+ return ScalarConvert(val, convert->target_ty, ctx);
+ });
+ if (!converted) {
+ return false;
+ }
+ value_stack.Push(converted);
+ return true;
+ },
+ [&](const constant::Splat* splat) {
+ const type::Type* target_el_ty = nullptr;
+ if (auto* str = convert->target_ty->As<type::Struct>()) {
+ // Structure conversion.
+ auto members = str->Members();
+ target_el_ty = members[0]->Type();
+
+ // Structures can only be converted during materialization. The user cannot
+ // declare the target structure type, so each member type must be the same
+ // default materialization type.
+ for (size_t i = 1; i < members.Length(); i++) {
+ if (members[i]->Type() != target_el_ty) {
+ TINT_ICE(Resolver, ctx.builder.Diagnostics())
+ << "inconsistent target struct member types for SplatConvert";
+ return false;
+ }
+ }
+ } else {
+ target_el_ty = type::Type::ElementOf(convert->target_ty);
+ }
+
+ // Convert the single splatted element type.
+ pending.Push(ActionBuildSplat{splat->count, convert->target_ty});
+ pending.Push(ActionConvert{splat->el, target_el_ty});
+ return true;
+ },
+ [&](const constant::Composite* composite) {
+ const size_t el_count = composite->NumElements();
+
+ // Build the new composite from the converted element types.
+ pending.Push(ActionBuildComposite{el_count, convert->target_ty});
+
+ if (auto* str = convert->target_ty->As<type::Struct>()) {
+ if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
+ TINT_ICE(Resolver, ctx.builder.Diagnostics())
+ << "const-eval conversion of structure has mismatched element counts";
+ return false;
+ }
+ // Struct composites can have different types for each member.
+ auto members = str->Members();
+ for (size_t i = 0; i < el_count; i++) {
+ pending.Push(ActionConvert{composite->Index(i), members[i]->Type()});
+ }
+ } else {
+ // Non-struct composites have the same type for all elements.
+ auto* el_ty = type::Type::ElementOf(convert->target_ty);
+ for (size_t i = 0; i < el_count; i++) {
+ auto* el = composite->Index(i);
+ pending.Push(ActionConvert{el, el_ty});
+ }
+ }
+
+ return true;
+ });
+ if (!ok) {
return nullptr;
}
- conv_els.Push(conv_el.Get());
}
- return builder.constants.Composite(target_ty, std::move(conv_els));
-}
-ConstEval::Result SplatConvert(const constant::Splat* splat,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics) {
- const type::Type* target_el_ty = nullptr;
- if (auto* str = target_ty->As<type::Struct>()) {
- // Structure conversion.
- auto members = str->Members();
- target_el_ty = members[0]->Type();
-
- // Structures can only be converted during materialization. The user cannot declare the
- // target structure type, so each member type must be the same default materialization type.
- for (size_t i = 1; i < members.Length(); i++) {
- if (members[i]->Type() != target_el_ty) {
- TINT_ICE(Resolver, builder.Diagnostics())
- << "inconsistent target struct member types for SplatConvert";
- return utils::Failure;
- }
- }
- } else {
- target_el_ty = type::Type::ElementOf(target_ty);
- }
- // Convert the single splatted element type.
- auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics);
- if (!conv_el) {
- return utils::Failure;
- }
- if (!conv_el.Get()) {
- return nullptr;
- }
- return builder.constants.Splat(target_ty, conv_el.Get(), splat->count);
-}
-
-ConstEval::Result ConvertInternal(const constant::Value* c,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics) {
- return Switch(
- c,
- [&](const constant::Scalar<tint::AFloat>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<tint::AInt>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<tint::u32>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<tint::i32>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<tint::f32>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<tint::f16>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Scalar<bool>* val) {
- return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Splat* val) {
- return SplatConvert(val, builder, target_ty, source, use_runtime_semantics);
- },
- [&](const constant::Composite* val) {
- return CompositeConvert(val, builder, target_ty, source, use_runtime_semantics);
- });
+ TINT_ASSERT(Resolver, value_stack.Length() == 1);
+ return value_stack.Pop();
}
namespace detail {
@@ -3752,7 +3795,9 @@
if (value->Type() == target_ty) {
return value;
}
- return ConvertInternal(value, builder, target_ty, source, use_runtime_semantics_);
+ ConvertContext ctx{builder, source, use_runtime_semantics_};
+ auto* converted = ConvertInternal(value, target_ty, ctx);
+ return converted ? Result(converted) : utils::Failure;
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index 0782633..4a98045 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -16,6 +16,7 @@
#include "src/tint/type/abstract_float.h"
#include "src/tint/type/abstract_int.h"
+#include "src/tint/type/array.h"
#include "src/tint/type/bool.h"
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
@@ -123,4 +124,32 @@
const type::Matrix* Manager::mat4x4(const type::Type* inner) {
return mat(inner, 4, 4);
}
+
+const type::Array* Manager::array(const type::Type* elem_ty,
+ uint32_t count,
+ uint32_t stride /* = 0*/) {
+ if (stride == 0) {
+ stride = elem_ty->Align();
+ }
+ return Get<type::Array>(/* element type */ elem_ty,
+ /* element count */ Get<ConstantArrayCount>(count),
+ /* array alignment */ elem_ty->Align(),
+ /* array size */ count * stride,
+ /* element stride */ stride,
+ /* implicit stride */ elem_ty->Align());
+}
+
+const type::Array* Manager::runtime_array(const type::Type* elem_ty, uint32_t stride /* = 0 */) {
+ if (stride == 0) {
+ stride = elem_ty->Align();
+ }
+ return Get<type::Array>(
+ /* element type */ elem_ty,
+ /* element count */ Get<RuntimeArrayCount>(),
+ /* array alignment */ elem_ty->Align(),
+ /* array size */ stride,
+ /* element stride */ stride,
+ /* implicit stride */ elem_ty->Align());
+}
+
} // namespace tint::type
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index b4fa560..6492e53 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -26,6 +26,7 @@
namespace tint::type {
class AbstractFloat;
class AbstractInt;
+class Array;
class Bool;
class F16;
class F32;
@@ -182,6 +183,17 @@
/// @returns the matrix type
const type::Matrix* mat4x4(const type::Type* inner);
+ /// @param elem_ty the array element type
+ /// @param count the array element count
+ /// @param stride the optional array element stride
+ /// @returns the array type
+ const type::Array* array(const type::Type* elem_ty, uint32_t count, uint32_t stride = 0);
+
+ /// @param elem_ty the array element type
+ /// @param stride the optional array element stride
+ /// @returns the runtime array type
+ const type::Array* runtime_array(const type::Type* elem_ty, uint32_t stride = 0);
+
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types
diff --git a/src/tint/utils/cli.cc b/src/tint/utils/cli.cc
new file mode 100644
index 0000000..2d3fd52
--- /dev/null
+++ b/src/tint/utils/cli.cc
@@ -0,0 +1,182 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/utils/cli.h"
+
+#include <algorithm>
+#include <sstream>
+#include <utility>
+
+#include "src/tint/utils/hashset.h"
+#include "src/tint/utils/string.h"
+#include "src/tint/utils/transform.h"
+
+namespace tint::utils::cli {
+
+Option::~Option() = default;
+
+void OptionSet::ShowHelp(std::ostream& s_out) {
+ utils::Vector<const Option*, 32> sorted_options;
+ for (auto* opt : options.Objects()) {
+ sorted_options.Push(opt);
+ }
+ sorted_options.Sort([](const Option* a, const Option* b) { return a->Name() < b->Name(); });
+
+ struct CmdInfo {
+ std::string left;
+ std::string right;
+ };
+ utils::Vector<CmdInfo, 64> cmd_infos;
+
+ for (auto* opt : sorted_options) {
+ {
+ std::stringstream left, right;
+ left << "--" << opt->Name();
+ if (auto param = opt->Parameter(); !param.empty()) {
+ left << " <" << param << ">";
+ }
+ right << opt->Description();
+ if (auto def = opt->DefaultValue(); !def.empty()) {
+ right << "\ndefault: " << def;
+ }
+ cmd_infos.Push({left.str(), right.str()});
+ }
+ if (auto alias = opt->Alias(); !alias.empty()) {
+ std::stringstream left, right;
+ left << "--" << alias;
+ right << "alias for --" << opt->Name();
+ cmd_infos.Push({left.str(), right.str()});
+ }
+ if (auto sn = opt->ShortName(); !sn.empty()) {
+ std::stringstream left, right;
+ left << " -" << sn;
+ right << "short name for --" << opt->Name();
+ cmd_infos.Push({left.str(), right.str()});
+ }
+ }
+
+ const size_t kMaxRightOffset = 30;
+
+ // Measure
+ size_t left_width = 0;
+ for (auto& cmd_info : cmd_infos) {
+ for (auto line : utils::Split(cmd_info.left, "\n")) {
+ if (line.length() < kMaxRightOffset) {
+ left_width = std::max(left_width, line.length());
+ }
+ }
+ }
+
+ // Print
+ left_width = std::min(left_width, kMaxRightOffset);
+
+ auto pad = [&](size_t n) {
+ while (n--) {
+ s_out << " ";
+ }
+ };
+
+ for (auto& cmd_info : cmd_infos) {
+ auto left_lines = utils::Split(cmd_info.left, "\n");
+ auto right_lines = utils::Split(cmd_info.right, "\n");
+
+ size_t num_lines = std::max(left_lines.Length(), right_lines.Length());
+ for (size_t i = 0; i < num_lines; i++) {
+ bool has_left = (i < left_lines.Length()) && !left_lines[i].empty();
+ bool has_right = (i < right_lines.Length()) && !right_lines[i].empty();
+ if (has_left) {
+ s_out << left_lines[i];
+ if (has_right) {
+ if (left_lines[i].length() > left_width) {
+ // Left exceeds column width.
+ // Insert a new line and indent to the right
+ s_out << std::endl;
+ pad(left_width);
+ } else {
+ pad(left_width - left_lines[i].length());
+ }
+ }
+ } else if (has_right) {
+ pad(left_width);
+ }
+ if (has_right) {
+ s_out << " " << right_lines[i];
+ }
+ s_out << std::endl;
+ }
+ }
+}
+
+Result<OptionSet::Unconsumed> OptionSet::Parse(std::ostream& s_err,
+ utils::VectorRef<std::string_view> arguments_raw) {
+ // Build a map of name to option, and set defaults
+ utils::Hashmap<std::string, Option*, 32> options_by_name;
+ for (auto* opt : options.Objects()) {
+ opt->SetDefault();
+ for (auto name : {opt->Name(), opt->Alias(), opt->ShortName()}) {
+ if (!name.empty() && !options_by_name.Add(name, opt)) {
+ s_err << "multiple options with name '" << name << "'" << std::endl;
+ return Failure;
+ }
+ }
+ }
+
+ // Canonicalize arguments by splitting '--foo=x' into '--foo' 'x'.
+ std::deque<std::string_view> arguments;
+ for (auto arg : arguments_raw) {
+ if (HasPrefix(arg, "-")) {
+ if (auto eq_idx = arg.find("="); eq_idx != std::string_view::npos) {
+ arguments.push_back(arg.substr(0, eq_idx));
+ arguments.push_back(arg.substr(eq_idx + 1));
+ continue;
+ }
+ }
+ arguments.push_back(arg);
+ }
+
+ utils::Hashset<Option*, 8> options_parsed;
+
+ Unconsumed unconsumed;
+ while (!arguments.empty()) {
+ auto arg = std::move(arguments.front());
+ arguments.pop_front();
+ auto name = TrimLeft(arg, [](char c) { return c == '-'; });
+ if (arg == name || name.length() == 0) {
+ unconsumed.Push(arg);
+ continue;
+ }
+ if (auto opt = options_by_name.Find(name)) {
+ if (auto err = (*opt)->Parse(arguments); !err.empty()) {
+ s_err << err << std::endl;
+ return Failure;
+ }
+ } else {
+ s_err << "unknown flag: " << arg << std::endl;
+ auto names = options_by_name.Keys();
+ auto alternatives =
+ Transform(names, [&](const std::string& s) { return std::string_view(s); });
+ utils::StringStream ss;
+ utils::SuggestAlternativeOptions opts;
+ opts.prefix = "--";
+ opts.list_possible_values = false;
+ SuggestAlternatives(arg, alternatives.Slice(), ss, opts);
+ s_err << ss.str();
+ return Failure;
+ }
+ }
+
+ return unconsumed;
+}
+
+} // namespace tint::utils::cli
diff --git a/src/tint/utils/cli.h b/src/tint/utils/cli.h
new file mode 100644
index 0000000..c1c3030
--- /dev/null
+++ b/src/tint/utils/cli.h
@@ -0,0 +1,410 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_UTILS_CLI_H_
+#define SRC_TINT_UTILS_CLI_H_
+
+#include <deque>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "src/tint/utils/block_allocator.h"
+#include "src/tint/utils/compiler_macros.h"
+#include "src/tint/utils/parse_num.h"
+#include "src/tint/utils/result.h"
+#include "src/tint/utils/string.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::utils::cli {
+
+/// Alias is a fluent-constructor helper for Options
+struct Alias {
+ /// The alias to apply to an Option
+ std::string value;
+
+ /// @param option the option to apply the alias to
+ template <typename T>
+ void Apply(T& option) {
+ option.alias = value;
+ }
+};
+
+/// ShortName is a fluent-constructor helper for Options
+struct ShortName {
+ /// The short-name to apply to an Option
+ std::string value;
+
+ /// @param option the option to apply the short name to
+ template <typename T>
+ void Apply(T& option) {
+ option.short_name = value;
+ }
+};
+
+/// Parameter is a fluent-constructor helper for Options
+struct Parameter {
+ /// The parameter name to apply to an Option
+ std::string value;
+
+ /// @param option the option to apply the parameter name to
+ template <typename T>
+ void Apply(T& option) {
+ option.parameter = value;
+ }
+};
+
+/// Default is a fluent-constructor helper for Options
+template <typename T>
+struct Default {
+ /// The default value to apply to an Option
+ T value;
+
+ /// @param option the option to apply the default value to
+ template <typename O>
+ void Apply(O& option) {
+ option.default_value = value;
+ }
+};
+
+/// Deduction guide for Default
+template <typename T>
+Default(T) -> Default<T>;
+
+/// Option is the base class for all command line options
+class Option {
+ public:
+ /// An alias to std::string, used to hold error messages.
+ using Error = std::string;
+
+ /// Destructor
+ virtual ~Option();
+
+ /// @return the name of the option, without any leading hyphens.
+ /// Example: 'help'
+ virtual std::string Name() const = 0;
+
+ /// @return the alias name of the option, without any leading hyphens. (optional)
+ /// Example: 'flag'
+ virtual std::string Alias() const = 0;
+
+ /// @return the shorter name of the option, without any leading hyphens. (optional)
+ /// Example: 'h'
+ virtual std::string ShortName() const = 0;
+
+ /// @return a string describing the parameter that the option expects.
+ /// Empty represents no expected parameter.
+ virtual std::string Parameter() const = 0;
+
+ /// @return a description of the option.
+ /// Example: 'shows this message'
+ virtual std::string Description() const = 0;
+
+ /// @return the default value of the option, or an empty string if there is no default value.
+ virtual std::string DefaultValue() const = 0;
+
+ /// Sets the option value to the default (called before arguments are parsed)
+ virtual void SetDefault() = 0;
+
+ /// Parses the option's arguments from the list of command line arguments, removing the consumed
+ /// arguments before returning. @p arguments will have already had the option's name consumed
+ /// before calling.
+ /// @param arguments the queue of unprocessed arguments. Parse() may take from the front of @p
+ /// arguments.
+ /// @return empty Error if successfully parsed, otherwise an error string.
+ virtual Error Parse(std::deque<std::string_view>& arguments) = 0;
+
+ protected:
+ /// An empty string, used to represent no-error.
+ static constexpr const char* Success = "";
+
+ /// @param expected the expected value(s) for the option
+ /// @return an Error message for a missing argument
+ Error ErrMissingArgument(std::string expected) const {
+ Error err = "missing value for option '--" + Name() + "'";
+ if (!expected.empty()) {
+ err += "Expected: " + expected;
+ }
+ return err;
+ }
+
+ /// @param got the argument value provided
+ /// @param reason the reason the argument is invalid (optional)
+ /// @return an Error message for an invalid argument
+ Error ErrInvalidArgument(std::string_view got, std::string reason) const {
+ Error err = "invalid value '" + std::string(got) + "' for option '--" + Name() + "'";
+ if (!reason.empty()) {
+ err += "\n" + reason;
+ }
+ return err;
+ }
+};
+
+/// OptionSet is a set of Options, which can parse the command line arguments.
+class OptionSet {
+ public:
+ /// Unconsumed is a list of unconsumed command line arguments
+ using Unconsumed = utils::Vector<std::string_view, 8>;
+
+ /// Constructs and returns a new Option to be owned by the OptionSet
+ /// @tparam T the Option type
+ /// @tparam ARGS the constructor argument types
+ /// @param args the constructor arguments
+ /// @return the constructed Option
+ template <typename T, typename... ARGS>
+ T& Add(ARGS&&... args) {
+ return *options.Create<T>(std::forward<ARGS>(args)...);
+ }
+
+ /// Prints to @p out the description of all the command line options.
+ /// @param out the output stream
+ void ShowHelp(std::ostream& out);
+
+ /// Parses all the options in @p options.
+ /// @param err the error stream
+ /// @param arguments the command line arguments, excluding the initial executable name
+ /// @return a Result holding a list of arguments that were not consumed as options
+ Result<Unconsumed> Parse(std::ostream& err, utils::VectorRef<std::string_view> arguments);
+
+ private:
+ /// The list of options to parse
+ utils::BlockAllocator<Option, 1024> options;
+};
+
+/// ValueOption is an option that accepts a single value
+template <typename T>
+class ValueOption : public Option {
+ static constexpr bool is_bool = std::is_same_v<T, bool>;
+ static constexpr bool is_number =
+ !is_bool && (std::is_integral_v<T> || std::is_floating_point_v<T>);
+ static constexpr bool is_string = std::is_same_v<T, std::string>;
+ static_assert(is_bool || is_number || is_string, "unsupported data type");
+
+ public:
+ /// The name of the option, without any leading hyphens.
+ std::string name;
+ /// The alias name of the option, without any leading hyphens.
+ std::string alias;
+ /// The shorter name of the option, without any leading hyphens.
+ std::string short_name;
+ /// A description of the option.
+ std::string description;
+ /// The default value.
+ std::optional<T> default_value;
+ /// The option value. Populated with Parse().
+ std::optional<T> value;
+ /// A string describing the name of the option's value.
+ std::string parameter = "value";
+
+ /// Constructor
+ ValueOption() = default;
+
+ /// Constructor
+ /// @param option_name the option name
+ /// @param option_description the option description
+ /// @param settings a number of fluent-constructor values that configure the option
+ /// @see ShortName, Parameter, Default
+ template <typename... SETTINGS>
+ ValueOption(std::string option_name, std::string option_description, SETTINGS&&... settings)
+ : name(std::move(option_name)), description(std::move(option_description)) {
+ (settings.Apply(*this), ...);
+ }
+
+ std::string Name() const override { return name; }
+
+ std::string Alias() const override { return alias; }
+
+ std::string ShortName() const override { return short_name; }
+
+ std::string Parameter() const override { return parameter; }
+
+ std::string Description() const override { return description; }
+
+ std::string DefaultValue() const override {
+ return default_value.has_value() ? ToString(*default_value) : "";
+ }
+
+ void SetDefault() override { value = default_value; }
+
+ Error Parse(std::deque<std::string_view>& arguments) override {
+ TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+
+ if (arguments.empty()) {
+ if constexpr (is_bool) {
+ // Treat as flag (--blah)
+ value = true;
+ return Success;
+ } else {
+ return ErrMissingArgument(parameter);
+ }
+ }
+
+ auto arg = arguments.front();
+
+ if constexpr (is_number) {
+ auto result = ParseNumber<T>(arg);
+ if (result) {
+ value = result.Get();
+ arguments.pop_front();
+ return Success;
+ }
+ if (result.Failure() == ParseNumberError::kResultOutOfRange) {
+ return ErrInvalidArgument(arg, "value out of range");
+ }
+ return ErrInvalidArgument(arg, "failed to parse value");
+ } else if constexpr (is_string) {
+ value = arg;
+ arguments.pop_front();
+ return Success;
+ } else if constexpr (is_bool) {
+ if (arg == "true") {
+ value = true;
+ arguments.pop_front();
+ return Success;
+ }
+ if (arg == "false") {
+ value = false;
+ arguments.pop_front();
+ return Success;
+ }
+ // Next argument is assumed to be another option, or unconsumed argument.
+ // Treat as flag (--blah)
+ value = true;
+ return Success;
+ }
+
+ TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+ }
+};
+
+/// BoolOption is an alias to ValueOption<bool>
+using BoolOption = ValueOption<bool>;
+
+/// StringOption is an alias to ValueOption<std::string>
+using StringOption = ValueOption<std::string>;
+
+/// EnumName is a pair of enum value and name.
+/// @tparam ENUM the enum type
+template <typename ENUM>
+struct EnumName {
+ /// Constructor
+ EnumName() = default;
+
+ /// Constructor
+ /// @param v the enum value
+ /// @param n the name of the enum value
+ EnumName(ENUM v, std::string n) : value(v), name(std::move(n)) {}
+
+ /// the enum value
+ ENUM value;
+ /// the name of the enum value
+ std::string name;
+};
+
+/// Deduction guide for EnumName
+template <typename ENUM>
+EnumName(ENUM, std::string) -> EnumName<ENUM>;
+
+/// EnumOption is an option that accepts an enumerator of values
+template <typename ENUM>
+class EnumOption : public Option {
+ public:
+ /// The name of the option, without any leading hyphens.
+ std::string name;
+ /// The alias name of the option, without any leading hyphens.
+ std::string alias;
+ /// The shorter name of the option, without any leading hyphens.
+ std::string short_name;
+ /// A description of the option.
+ std::string description;
+ /// The enum options as a pair of enum value to name
+ utils::Vector<EnumName<ENUM>, 8> enum_names;
+ /// The default value.
+ std::optional<ENUM> default_value;
+ /// The option value. Populated with Parse().
+ std::optional<ENUM> value;
+
+ /// Constructor
+ EnumOption() = default;
+
+ /// Constructor
+ /// @param option_name the option name
+ /// @param option_description the option description
+ /// @param names The enum options as a pair of enum value to name
+ /// @param settings a number of fluent-constructor values that configure the option
+ /// @see ShortName, Parameter, Default
+ template <typename... SETTINGS>
+ EnumOption(std::string option_name,
+ std::string option_description,
+ utils::VectorRef<EnumName<ENUM>> names,
+ SETTINGS&&... settings)
+ : name(std::move(option_name)),
+ description(std::move(option_description)),
+ enum_names(std::move(names)) {
+ (settings.Apply(*this), ...);
+ }
+
+ std::string Name() const override { return name; }
+
+ std::string ShortName() const override { return short_name; }
+
+ std::string Alias() const override { return alias; }
+
+ std::string Parameter() const override { return PossibleValues("|"); }
+
+ std::string Description() const override { return description; }
+
+ std::string DefaultValue() const override {
+ for (auto& enum_name : enum_names) {
+ if (enum_name.value == default_value) {
+ return enum_name.name;
+ }
+ }
+ return "";
+ }
+
+ void SetDefault() override { value = default_value; }
+
+ Error Parse(std::deque<std::string_view>& arguments) override {
+ if (arguments.empty()) {
+ return ErrMissingArgument("one of: " + PossibleValues(", "));
+ }
+ auto& arg = arguments.front();
+ for (auto& enum_name : enum_names) {
+ if (enum_name.name == arg) {
+ value = enum_name.value;
+ arguments.pop_front();
+ return Success;
+ }
+ }
+ return ErrInvalidArgument(arg, "Must be one of: " + PossibleValues(", "));
+ }
+
+ /// @param delimiter the delimiter between each enum option
+ /// @returns the accepted enum names delimited with @p delimiter
+ std::string PossibleValues(std::string delimiter) const {
+ std::string out;
+ for (auto& enum_name : enum_names) {
+ if (!out.empty()) {
+ out += delimiter;
+ }
+ out += enum_name.name;
+ }
+ return out;
+ }
+};
+
+} // namespace tint::utils::cli
+
+#endif // SRC_TINT_UTILS_CLI_H_
diff --git a/src/tint/utils/cli_test.cc b/src/tint/utils/cli_test.cc
new file mode 100644
index 0000000..cabbf14
--- /dev/null
+++ b/src/tint/utils/cli_test.cc
@@ -0,0 +1,299 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/utils/cli.h"
+
+#include <sstream>
+
+#include "gmock/gmock.h"
+#include "src/tint/utils/string.h"
+
+#include "src/tint/utils/transform.h" // Used by ToStringList()
+
+namespace tint::utils::cli {
+namespace {
+
+// Workaround for https://github.com/google/googletest/issues/3081
+// Remove when using C++20
+template <size_t N>
+utils::Vector<std::string, N> ToStringList(const utils::Vector<std::string_view, N>& views) {
+ return Transform(views, [](std::string_view view) { return std::string(view); });
+}
+
+using CLITest = testing::Test;
+
+TEST_F(CLITest, ShowHelp_ValueWithParameter) {
+ OptionSet opts;
+ opts.Add<ValueOption<int>>("my_option", "sets the awesome value");
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--my_option <value> sets the awesome value
+)");
+}
+
+TEST_F(CLITest, ShowHelp_ValueWithAlias) {
+ OptionSet opts;
+ opts.Add<ValueOption<int>>("my_option", "sets the awesome value", Alias{"alias"});
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--my_option <value> sets the awesome value
+--alias alias for --my_option
+)");
+}
+TEST_F(CLITest, ShowHelp_ValueWithShortName) {
+ OptionSet opts;
+ opts.Add<ValueOption<int>>("my_option", "sets the awesome value", ShortName{"a"});
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--my_option <value> sets the awesome value
+ -a short name for --my_option
+)");
+}
+
+TEST_F(CLITest, ShowHelp_MultilineDesc) {
+ OptionSet opts;
+ opts.Add<ValueOption<int>>("an-option", R"(this is a
+multi-line description
+for an option
+)");
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--an-option <value> this is a
+ multi-line description
+ for an option
+
+)");
+}
+
+TEST_F(CLITest, ShowHelp_LongName) {
+ OptionSet opts;
+ opts.Add<ValueOption<int>>("an-option-with-a-really-really-long-name",
+ "this is an option that has a silly long name", ShortName{"a"});
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--an-option-with-a-really-really-long-name <value>
+ this is an option that has a silly long name
+ -a short name for --an-option-with-a-really-really-long-name
+)");
+}
+
+TEST_F(CLITest, ShowHelp_EnumValue) {
+ enum class E { X, Y, Z };
+
+ OptionSet opts;
+ opts.Add<EnumOption<E>>("my_enum_option", "sets the awesome value",
+ utils::Vector{
+ EnumName(E::X, "X"),
+ EnumName(E::Y, "Y"),
+ EnumName(E::Z, "Z"),
+ });
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--my_enum_option <X|Y|Z> sets the awesome value
+)");
+}
+
+TEST_F(CLITest, ShowHelp_MixedValues) {
+ enum class E { X, Y, Z };
+
+ OptionSet opts;
+
+ opts.Add<ValueOption<int>>("option-a", "an integer");
+ opts.Add<BoolOption>("option-b", "a boolean");
+ opts.Add<EnumOption<E>>("option-c", "sets the awesome value",
+ utils::Vector{
+ EnumName(E::X, "X"),
+ EnumName(E::Y, "Y"),
+ EnumName(E::Z, "Z"),
+ });
+
+ std::stringstream out;
+ out << std::endl;
+ opts.ShowHelp(out);
+ EXPECT_EQ(out.str(), R"(
+--option-a <value> an integer
+--option-b <value> a boolean
+--option-c <X|Y|Z> sets the awesome value
+)");
+}
+
+TEST_F(CLITest, ParseBool_Flag) {
+ OptionSet opts;
+ auto& opt = opts.Add<BoolOption>("my_option", "a boolean value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option unconsumed", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre("unconsumed"));
+ EXPECT_EQ(opt.value, true);
+}
+
+TEST_F(CLITest, ParseBool_ExplicitTrue) {
+ OptionSet opts;
+ auto& opt = opts.Add<BoolOption>("my_option", "a boolean value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option true", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, true);
+}
+
+TEST_F(CLITest, ParseBool_ExplicitFalse) {
+ OptionSet opts;
+ auto& opt = opts.Add<BoolOption>("my_option", "a boolean value", Default{true});
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option false", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, false);
+}
+
+TEST_F(CLITest, ParseInt) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<int>>("my_option", "an integer value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option 42", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, 42);
+}
+
+TEST_F(CLITest, ParseUint64) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<uint64_t>>("my_option", "a uint64_t value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option 1000000", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, 1000000);
+}
+
+TEST_F(CLITest, ParseFloat) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<float>>("my_option", "a float value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option 1.25", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, 1.25f);
+}
+
+TEST_F(CLITest, ParseString) {
+ OptionSet opts;
+ auto& opt = opts.Add<StringOption>("my_option", "a string value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option blah", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, "blah");
+}
+
+TEST_F(CLITest, ParseEnum) {
+ enum class E { X, Y, Z };
+
+ OptionSet opts;
+ auto& opt = opts.Add<EnumOption<E>>("my_option", "sets the awesome value",
+ utils::Vector{
+ EnumName(E::X, "X"),
+ EnumName(E::Y, "Y"),
+ EnumName(E::Z, "Z"),
+ });
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option Y", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, E::Y);
+}
+
+TEST_F(CLITest, ParseShortName) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<int>>("my_option", "an integer value", ShortName{"o"});
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("-o 42", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, 42);
+}
+
+TEST_F(CLITest, ParseUnconsumed) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<int32_t>>("my_option", "a int32_t value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("abc --my_option -123 def", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre("abc", "def"));
+ EXPECT_EQ(opt.value, -123);
+}
+
+TEST_F(CLITest, ParseUsingEquals) {
+ OptionSet opts;
+ auto& opt = opts.Add<ValueOption<int>>("my_option", "an int value");
+
+ std::stringstream err;
+ auto res = opts.Parse(err, Split("--my_option=123", " "));
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_THAT(ToStringList(res.Get()), testing::ElementsAre());
+ EXPECT_EQ(opt.value, 123);
+}
+
+TEST_F(CLITest, SetValueToDefault) {
+ OptionSet opts;
+ auto& opt = opts.Add<BoolOption>("my_option", "a boolean value", Default{true});
+
+ std::stringstream err;
+ auto res = opts.Parse(err, utils::Empty);
+ ASSERT_TRUE(res) << err.str();
+ EXPECT_TRUE(err.str().empty());
+ EXPECT_EQ(opt.value, true);
+}
+
+} // namespace
+} // namespace tint::utils::cli
diff --git a/src/tint/utils/parse_num.h b/src/tint/utils/parse_num.h
index 6d4fcb4..7c4ff7e 100644
--- a/src/tint/utils/parse_num.h
+++ b/src/tint/utils/parse_num.h
@@ -18,6 +18,7 @@
#include <optional>
#include <string>
+#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
namespace tint::utils {
@@ -78,6 +79,8 @@
/// @returns the string @p str parsed as a uint8_t
Result<uint8_t, ParseNumberError> ParseUint8(std::string_view str);
+TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+
/// @param str the string
/// @returns the string @p str parsed as a the number @p T
template <typename T>
@@ -121,6 +124,8 @@
return ParseNumberError::kUnparsable;
}
+TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+
} // namespace tint::utils
#endif // SRC_TINT_UTILS_PARSE_NUM_H_
diff --git a/src/tint/utils/string_test.cc b/src/tint/utils/string_test.cc
index 676c341..63f72a6 100644
--- a/src/tint/utils/string_test.cc
+++ b/src/tint/utils/string_test.cc
@@ -17,9 +17,18 @@
#include "gmock/gmock.h"
#include "src/tint/utils/string_stream.h"
+#include "src/tint/utils/transform.h" // Used by ToStringList()
+
namespace tint::utils {
namespace {
+// Workaround for https://github.com/google/googletest/issues/3081
+// Remove when using C++20
+template <size_t N>
+utils::Vector<std::string, N> ToStringList(const utils::Vector<std::string_view, N>& views) {
+ return Transform(views, [](std::string_view view) { return std::string(view); });
+}
+
TEST(StringTest, ReplaceAll) {
EXPECT_EQ("xybbcc", ReplaceAll("aabbcc", "aa", "xy"));
EXPECT_EQ("aaxycc", ReplaceAll("aabbcc", "bb", "xy"));
@@ -176,16 +185,15 @@
EXPECT_EQ("'meow'", Quote("meow"));
}
-#if 0 // Enable when moved to C++20 (https://github.com/google/googletest/issues/3081)
TEST(StringTest, Split) {
- EXPECT_THAT(Split("", ","), testing::ElementsAre(""));
- EXPECT_THAT(Split("cat", ","), testing::ElementsAre("cat"));
- EXPECT_THAT(Split("cat,", ","), testing::ElementsAre("cat", ""));
- EXPECT_THAT(Split(",cat", ","), testing::ElementsAre("", "cat"));
- EXPECT_THAT(Split("cat,dog,fish", ","), testing::ElementsAre("cat", "dog", "fish"));
- EXPECT_THAT(Split("catdogfish", "dog"), testing::ElementsAre("cat", "fish"));
+ EXPECT_THAT(ToStringList(Split("", ",")), testing::ElementsAre(""));
+ EXPECT_THAT(ToStringList(Split("cat", ",")), testing::ElementsAre("cat"));
+ EXPECT_THAT(ToStringList(Split("cat,", ",")), testing::ElementsAre("cat", ""));
+ EXPECT_THAT(ToStringList(Split(",cat", ",")), testing::ElementsAre("", "cat"));
+ EXPECT_THAT(ToStringList(Split("cat,dog,fish", ",")),
+ testing::ElementsAre("cat", "dog", "fish"));
+ EXPECT_THAT(ToStringList(Split("catdogfish", "dog")), testing::ElementsAre("cat", "fish"));
}
-#endif
TEST(StringTest, Join) {
EXPECT_EQ(Join(utils::Vector<int, 1>{}, ","), "");
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index a32f5b7..8ae01c8 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -18,6 +18,7 @@
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
+#include "src/tint/constant/scalar.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/break_if.h"
@@ -39,6 +40,7 @@
#include "src/tint/ir/var.h"
#include "src/tint/switch.h"
#include "src/tint/transform/manager.h"
+#include "src/tint/type/array.h"
#include "src/tint/type/bool.h"
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
@@ -100,11 +102,9 @@
// TODO(crbug.com/tint/1906): Emit extensions.
- // TODO(crbug.com/tint/1906): Emit variables.
- (void)zero_init_workgroup_memory_;
+ // Emit module-scope declarations.
if (ir_->root_block) {
- TINT_ICE(Writer, diagnostics_) << "root block is unimplemented";
- return false;
+ EmitRootBlock(ir_->root_block);
}
// Emit functions.
@@ -174,6 +174,14 @@
});
}
+uint32_t GeneratorImplIr::ConstantNull(const type::Type* type) {
+ return constant_nulls_.GetOrCreate(type, [&]() {
+ auto id = module_.NextId();
+ module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
+ return id;
+ });
+}
+
uint32_t GeneratorImplIr::Type(const type::Type* ty) {
return types_.GetOrCreate(ty, [&]() {
auto id = module_.NextId();
@@ -200,6 +208,18 @@
module_.PushType(spv::Op::OpTypeMatrix,
{id, Type(mat->ColumnType()), mat->columns()});
},
+ [&](const type::Array* arr) {
+ if (arr->ConstantCount()) {
+ auto* count = ir_->constant_values.Get(u32(arr->ConstantCount().value()));
+ module_.PushType(spv::Op::OpTypeArray,
+ {id, Type(arr->ElemType()), Constant(count)});
+ } else {
+ TINT_ASSERT(Writer, arr->Count()->Is<type::RuntimeArrayCount>());
+ module_.PushType(spv::Op::OpTypeRuntimeArray, {id, Type(arr->ElemType())});
+ }
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationArrayStride), arr->Stride()});
+ },
[&](const type::Pointer* ptr) {
module_.PushType(
spv::Op::OpTypePointer,
@@ -320,6 +340,20 @@
{U32Operand(stage), id, ir_->NameOf(func).Name()});
}
+void GeneratorImplIr::EmitRootBlock(const ir::Block* root_block) {
+ for (auto* inst : root_block->Instructions()) {
+ auto result = Switch(
+ inst, //
+ [&](const ir::Var* v) { return EmitVar(v); },
+ [&](Default) {
+ TINT_ICE(Writer, diagnostics_)
+ << "unimplemented root block instruction: " << inst->TypeInfo().name;
+ return 0u;
+ });
+ values_.Add(inst, result);
+ }
+}
+
void GeneratorImplIr::EmitBlock(const ir::Block* block) {
// Emit the label.
// Skip if this is the function's entry block, as it will be emitted by the function object.
@@ -334,6 +368,22 @@
return;
}
+ // Emit Phi nodes for all the incoming block parameters
+ for (size_t param_idx = 0; param_idx < block->Params().Length(); param_idx++) {
+ auto* param = block->Params()[param_idx];
+ auto id = module_.NextId();
+ values_.Add(param, id);
+ OperandList ops{Type(param->Type()), id};
+
+ for (auto* incoming : block->InboundBranches()) {
+ auto* arg = incoming->Args()[param_idx];
+ ops.push_back(Value(arg));
+ ops.push_back(Label(incoming->Block()));
+ }
+
+ current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
+ }
+
// Emit the instructions.
for (auto* inst : block->Instructions()) {
auto result = Switch(
@@ -391,7 +441,7 @@
{
Value(breakif->Condition()),
Label(breakif->Loop()->Merge()),
- Label(breakif->Loop()->Start()),
+ Label(breakif->Loop()->Body()),
});
},
[&](const ir::Continue* cont) {
@@ -407,7 +457,7 @@
current_function_.push_inst(spv::Op::OpBranch, {Label(swtch->Switch()->Merge())});
},
[&](const ir::NextIteration* loop) {
- current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Start())});
+ current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Body())});
},
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unimplemented branch: " << b->TypeInfo().name;
@@ -421,16 +471,17 @@
// Generate labels for the blocks. We emit the true or false block if it:
// 1. contains instructions other then the branch, or
- // 2. branches somewhere other then the Merge().
+ // 2. branches somewhere other then the Merge(), or
+ // 3. the merge has input parameters
// Otherwise we skip them and branch straight to the merge block.
uint32_t merge_label = Label(merge_block);
uint32_t true_label = merge_label;
uint32_t false_label = merge_label;
- if (true_block->Instructions().Length() > 1 ||
+ if (true_block->Instructions().Length() > 1 || !merge_block->Params().IsEmpty() ||
(true_block->HasBranchTarget() && !true_block->Branch()->Is<ir::ExitIf>())) {
true_label = Label(true_block);
}
- if (false_block->Instructions().Length() > 1 ||
+ if (false_block->Instructions().Length() > 1 || !merge_block->Params().IsEmpty() ||
(false_block->HasBranchTarget() && !false_block->Branch()->Is<ir::ExitIf>())) {
false_label = Label(false_block);
}
@@ -631,7 +682,7 @@
void GeneratorImplIr::EmitLoop(const ir::Loop* loop) {
auto header_label = module_.NextId();
- auto body_label = Label(loop->Start());
+ auto body_label = Label(loop->Body());
auto continuing_label = Label(loop->Continuing());
auto merge_label = Label(loop->Merge());
@@ -643,11 +694,11 @@
current_function_.push_inst(spv::Op::OpBranch, {body_label});
// Emit the loop body.
- EmitBlock(loop->Start());
+ EmitBlock(loop->Body());
// Emit the loop continuing block.
// The back-edge needs to go to the loop header, so update the label for the start block.
- block_labels_.Replace(loop->Start(), header_label);
+ block_labels_.Replace(loop->Body(), header_label);
if (loop->Continuing()->HasBranchTarget()) {
EmitBlock(loop->Continuing());
} else {
@@ -719,16 +770,41 @@
TINT_ASSERT(Writer, ptr);
auto ty = Type(ptr);
- if (ptr->AddressSpace() == builtin::AddressSpace::kFunction) {
- TINT_ASSERT(Writer, current_function_);
- current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)});
- if (var->Initializer()) {
- current_function_.push_inst(spv::Op::OpStore, {id, Value(var->Initializer())});
+ switch (ptr->AddressSpace()) {
+ case builtin::AddressSpace::kFunction: {
+ TINT_ASSERT(Writer, current_function_);
+ current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)});
+ if (var->Initializer()) {
+ current_function_.push_inst(spv::Op::OpStore, {id, Value(var->Initializer())});
+ }
+ break;
}
- } else {
- TINT_ICE(Writer, diagnostics_)
- << "unimplemented variable address space " << ptr->AddressSpace();
- return 0u;
+ case builtin::AddressSpace::kPrivate: {
+ TINT_ASSERT(Writer, !current_function_);
+ OperandList operands = {ty, id, U32Operand(SpvStorageClassPrivate)};
+ if (var->Initializer()) {
+ TINT_ASSERT(Writer, var->Initializer()->Is<ir::Constant>());
+ operands.push_back(Value(var->Initializer()));
+ }
+ module_.PushType(spv::Op::OpVariable, operands);
+ break;
+ }
+ case builtin::AddressSpace::kWorkgroup: {
+ TINT_ASSERT(Writer, !current_function_);
+ OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
+ if (zero_init_workgroup_memory_) {
+ // If requested, use the VK_KHR_zero_initialize_workgroup_memory to zero-initialize
+ // the workgroup variable using an null constant initializer.
+ operands.push_back(ConstantNull(ptr->StoreType()));
+ }
+ module_.PushType(spv::Op::OpVariable, operands);
+ break;
+ }
+ default: {
+ TINT_ICE(Writer, diagnostics_)
+ << "unimplemented variable address space " << ptr->AddressSpace();
+ return 0u;
+ }
}
// Set the name if present.
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 3c8783e..5793325 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -76,6 +76,11 @@
/// @returns the result ID of the constant
uint32_t Constant(const ir::Constant* constant);
+ /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
+ /// @param type the type to get the ID for
+ /// @returns the result ID of the OpConstantNull instruction
+ uint32_t ConstantNull(const type::Type* type);
+
/// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
/// @param ty the type to get the ID for
/// @returns the result ID of the type
@@ -104,6 +109,10 @@
/// @param block the block to emit
void EmitBlock(const ir::Block* block);
+ /// Emit the root block.
+ /// @param root_block the root block to emit
+ void EmitRootBlock(const ir::Block* root_block);
+
/// Emit an `if` flow node.
/// @param i the if node to emit
void EmitIf(const ir::If* i);
@@ -194,6 +203,9 @@
/// The map of constants to their result IDs.
utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
+ /// The map of types to the result IDs of their OpConstantNull instructions.
+ utils::Hashmap<const type::Type*, uint32_t, 4> constant_nulls_;
+
/// The map of non-constant values to their result IDs.
utils::Hashmap<const ir::Value*, uint32_t, 8> values_;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index f43a09a..1c9b9ac 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -142,5 +142,160 @@
)");
}
+TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param = b.BlockParam(b.ir.Types().i32());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ i->True()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i)})});
+ i->False()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i)})});
+ i->Merge()->SetParams(utils::Vector{merge_param});
+ i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+
+ func->StartTarget()->SetInstructions(utils::Vector{i});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%11 = OpTypeInt 32 1
+%12 = OpConstant %11 10
+%13 = OpConstant %11 20
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpBranch %5
+%7 = OpLabel
+OpBranch %5
+%5 = OpLabel
+%10 = OpPhi %11 %12 %6 %13 %7
+OpReturnValue %10
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_TrueReturn) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param = b.BlockParam(b.ir.Types().i32());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ i->True()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{b.Constant(42_i)})});
+ i->False()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i)})});
+ i->Merge()->SetParams(utils::Vector{merge_param});
+ i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+
+ func->StartTarget()->SetInstructions(utils::Vector{i});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%11 = OpTypeInt 32 1
+%10 = OpConstant %11 42
+%13 = OpConstant %11 20
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpReturnValue %10
+%7 = OpLabel
+OpBranch %5
+%5 = OpLabel
+%12 = OpPhi %11 %13 %7
+OpReturnValue %12
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_FalseReturn) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param = b.BlockParam(b.ir.Types().i32());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ i->True()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i)})});
+ i->False()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{b.Constant(42_i)})});
+ i->Merge()->SetParams(utils::Vector{merge_param});
+ i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+
+ func->StartTarget()->SetInstructions(utils::Vector{i});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%11 = OpTypeInt 32 1
+%10 = OpConstant %11 42
+%13 = OpConstant %11 10
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpBranch %5
+%7 = OpLabel
+OpReturnValue %10
+%5 = OpLabel
+%12 = OpPhi %11 %13 %6
+OpReturnValue %12
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, If_Phi_MultipleValue) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
+ auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
+
+ auto* i = b.CreateIf(b.Constant(true));
+ i->True()->SetInstructions(
+ utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i), b.Constant(true)})});
+ i->False()->SetInstructions(
+ utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i), b.Constant(false)})});
+ i->Merge()->SetParams(utils::Vector{merge_param_0, merge_param_1});
+ i->Merge()->SetInstructions(utils::Vector{
+ b.Return(func, utils::Vector{merge_param_0}),
+ });
+
+ func->StartTarget()->SetInstructions(utils::Vector{i});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%9 = OpTypeBool
+%8 = OpConstantTrue %9
+%11 = OpTypeInt 32 1
+%12 = OpConstant %11 10
+%13 = OpConstant %11 20
+%15 = OpConstantFalse %9
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %8 %6 %7
+%6 = OpLabel
+OpBranch %5
+%7 = OpLabel
+OpBranch %5
+%5 = OpLabel
+%10 = OpPhi %11 %12 %6 %13 %7
+%14 = OpPhi %9 %8 %6 %15 %7
+OpReturnValue %10
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index e16d116..a32a751 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -24,7 +24,7 @@
auto* loop = b.CreateLoop();
- loop->Start()->AddInstruction(b.Continue(loop));
+ loop->Body()->AddInstruction(b.Continue(loop));
loop->Continuing()->AddInstruction(b.BreakIf(b.Constant(true), loop));
loop->Merge()->AddInstruction(b.Return(func));
@@ -58,7 +58,7 @@
auto* loop = b.CreateLoop();
- loop->Start()->AddInstruction(b.ExitLoop(loop));
+ loop->Body()->AddInstruction(b.ExitLoop(loop));
loop->Merge()->AddInstruction(b.Return(func));
func->StartTarget()->AddInstruction(loop);
@@ -93,7 +93,7 @@
cond_break->False()->AddInstruction(b.ExitIf(cond_break));
cond_break->Merge()->AddInstruction(b.Continue(loop));
- loop->Start()->AddInstruction(cond_break);
+ loop->Body()->AddInstruction(cond_break);
loop->Continuing()->AddInstruction(b.NextIteration(loop));
loop->Merge()->AddInstruction(b.Return(func));
@@ -136,7 +136,7 @@
cond_break->False()->AddInstruction(b.ExitIf(cond_break));
cond_break->Merge()->AddInstruction(b.ExitLoop(loop));
- loop->Start()->AddInstruction(cond_break);
+ loop->Body()->AddInstruction(cond_break);
loop->Continuing()->AddInstruction(b.NextIteration(loop));
loop->Merge()->AddInstruction(b.Return(func));
@@ -176,7 +176,7 @@
auto* loop = b.CreateLoop();
- loop->Start()->AddInstruction(b.Return(func));
+ loop->Body()->AddInstruction(b.Return(func));
func->StartTarget()->AddInstruction(loop);
@@ -207,7 +207,7 @@
auto* result = b.Equal(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));
- loop->Start()->AddInstruction(result);
+ loop->Body()->AddInstruction(result);
loop->Continuing()->AddInstruction(b.BreakIf(result, loop));
loop->Merge()->AddInstruction(b.Return(func));
@@ -242,11 +242,11 @@
auto* outer_loop = b.CreateLoop();
auto* inner_loop = b.CreateLoop();
- inner_loop->Start()->AddInstruction(b.ExitLoop(inner_loop));
+ inner_loop->Body()->AddInstruction(b.ExitLoop(inner_loop));
inner_loop->Continuing()->AddInstruction(b.NextIteration(inner_loop));
inner_loop->Merge()->AddInstruction(b.Continue(outer_loop));
- outer_loop->Start()->AddInstruction(inner_loop);
+ outer_loop->Body()->AddInstruction(inner_loop);
outer_loop->Continuing()->AddInstruction(b.BreakIf(b.Constant(true), outer_loop));
outer_loop->Merge()->AddInstruction(b.Return(func));
@@ -289,11 +289,11 @@
auto* outer_loop = b.CreateLoop();
auto* inner_loop = b.CreateLoop();
- inner_loop->Start()->AddInstruction(b.Continue(inner_loop));
+ inner_loop->Body()->AddInstruction(b.Continue(inner_loop));
inner_loop->Continuing()->AddInstruction(b.BreakIf(b.Constant(true), inner_loop));
inner_loop->Merge()->AddInstruction(b.BreakIf(b.Constant(true), outer_loop));
- outer_loop->Start()->AddInstruction(b.Continue(outer_loop));
+ outer_loop->Body()->AddInstruction(b.Continue(outer_loop));
outer_loop->Continuing()->AddInstruction(inner_loop);
outer_loop->Merge()->AddInstruction(b.Return(func));
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index 92800a3..1c4d5d0 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -217,5 +217,131 @@
)");
}
+TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param = b.BlockParam(b.ir.Types().i32());
+
+ auto* s = b.CreateSwitch(b.Constant(42_i));
+ auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->AddInstruction(b.ExitSwitch(s, utils::Vector{b.Constant(10_i)}));
+
+ auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->AddInstruction(b.ExitSwitch(s, utils::Vector{b.Constant(20_i)}));
+
+ s->Merge()->SetParams(utils::Vector{merge_param});
+ s->Merge()->AddInstruction(b.Return(func));
+
+ func->StartTarget()->SetInstructions(utils::Vector{s});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%11 = OpConstant %7 10
+%12 = OpConstant %7 20
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %6 %5 1 %5 2 %8
+%5 = OpLabel
+OpBranch %9
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+%10 = OpPhi %7 %11 %5 %12 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue_CaseReturn) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* s = b.CreateSwitch(b.Constant(42_i));
+ auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->AddInstruction(b.Return(func, utils::Vector{b.Constant(10_i)}));
+
+ auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->AddInstruction(b.ExitSwitch(s, utils::Vector{b.Constant(20_i)}));
+
+ s->Merge()->SetParams(utils::Vector{b.BlockParam(b.ir.Types().i32())});
+ s->Merge()->AddInstruction(b.Return(func));
+
+ func->StartTarget()->SetInstructions(utils::Vector{s});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%10 = OpConstant %7 10
+%12 = OpConstant %7 20
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %6 %5 1 %5 2 %8
+%5 = OpLabel
+OpReturnValue %10
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+%11 = OpPhi %7 %12 %8
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_Phi_MultipleValue) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
+ auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
+
+ auto* s = b.CreateSwitch(b.Constant(42_i));
+ auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->AddInstruction(b.ExitSwitch(s, utils::Vector{b.Constant(10_i), b.Constant(true)}));
+
+ auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->AddInstruction(b.ExitSwitch(s, utils::Vector{b.Constant(20_i), b.Constant(false)}));
+
+ s->Merge()->SetParams(utils::Vector{merge_param_0, merge_param_1});
+ s->Merge()->AddInstruction(b.Return(func, utils::Vector{merge_param_0}));
+
+ func->StartTarget()->SetInstructions(utils::Vector{s});
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%11 = OpConstant %7 10
+%12 = OpConstant %7 20
+%14 = OpTypeBool
+%15 = OpConstantTrue %14
+%16 = OpConstantFalse %14
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %6 %5 1 %5 2 %8
+%5 = OpLabel
+OpBranch %9
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+%10 = OpPhi %7 %11 %5 %12 %8
+%13 = OpPhi %14 %15 %5 %16 %8
+OpReturnValue %10
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index 80e69c3..7b2142c 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -125,6 +125,50 @@
"%1 = OpTypeMatrix %2 4\n");
}
+TEST_F(SpvGeneratorImplTest, Type_Array_DefaultStride) {
+ auto* arr = mod.Types().array(mod.Types().f32(), 4u);
+ auto id = generator_.Type(arr);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%2 = OpTypeFloat 32\n"
+ "%4 = OpTypeInt 32 0\n"
+ "%3 = OpConstant %4 4\n"
+ "%1 = OpTypeArray %2 %3\n");
+ EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 4\n");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_Array_ExplicitStride) {
+ auto* arr = mod.Types().array(mod.Types().f32(), 4u, 16);
+ auto id = generator_.Type(arr);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%2 = OpTypeFloat 32\n"
+ "%4 = OpTypeInt 32 0\n"
+ "%3 = OpConstant %4 4\n"
+ "%1 = OpTypeArray %2 %3\n");
+ EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 16\n");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_DefaultStride) {
+ auto* arr = mod.Types().runtime_array(mod.Types().f32());
+ auto id = generator_.Type(arr);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%2 = OpTypeFloat 32\n"
+ "%1 = OpTypeRuntimeArray %2\n");
+ EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 4\n");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_ExplicitStride) {
+ auto* arr = mod.Types().runtime_array(mod.Types().f32(), 16);
+ auto id = generator_.Type(arr);
+ EXPECT_EQ(id, 1u);
+ EXPECT_EQ(DumpTypes(),
+ "%2 = OpTypeFloat 32\n"
+ "%1 = OpTypeRuntimeArray %2\n");
+ EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 16\n");
+}
+
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpvGeneratorImplTest, Type_Multiple) {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 574a0b0..8ad2c10 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -180,5 +180,235 @@
)");
}
+TEST_F(SpvGeneratorImplTest, PrivateVar_NoInit) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
+ builtin::Access::kReadWrite);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, PrivateVar_WithInit) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ v->SetInitializer(b.Constant(42_i));
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpName %5 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstant %3 42
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, PrivateVar_Name) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kPrivate,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ v->SetInitializer(b.Constant(42_i));
+ mod.SetName(v, "myvar");
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpName %1 "myvar"
+OpName %5 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstant %3 42
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, PrivateVar_LoadAndStore) {
+ auto* func =
+ b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kFragment);
+ mod.functions.Push(func);
+
+ auto* store_ty = mod.Types().i32();
+ auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kPrivate,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ v->SetInitializer(b.Constant(42_i));
+
+ auto* load = b.Load(v);
+ auto* add = b.Add(store_ty, v, b.Constant(1_i));
+ auto* store = b.Store(v, add);
+ func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %5 "foo"
+OpExecutionMode %5 OriginUpperLeft
+OpName %5 "foo"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstant %3 42
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%11 = OpConstant %3 1
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+%9 = OpLoad %3 %1
+%10 = OpIAdd %3 %1 %11
+OpStore %1 %10
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ mod.SetName(v, "myvar");
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %1 "myvar"
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
+ mod.functions.Push(func);
+
+ auto* store_ty = mod.Types().i32();
+ auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ auto* load = b.Load(v);
+ auto* add = b.Add(store_ty, v, b.Constant(1_i));
+ auto* store = b.Store(v, add);
+ func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "foo"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "foo"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%10 = OpConstant %3 1
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+%8 = OpLoad %3 %1
+%9 = OpIAdd %3 %1 %10
+OpStore %1 %9
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+ // Create a generator with the zero_init_workgroup_memory flag set to `true`.
+ spirv::GeneratorImplIr gen(&mod, true);
+ gen.Generate();
+ EXPECT_EQ(DumpModule(gen.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpName %5 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Workgroup %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv