Import Tint changes from Dawn
Changes:
- edd2305fe411613e031da2232fe7f8c4ca9acef1 [msl] Validate generated MSL in writer unit tests by James Price <jrprice@google.com>
- 1d948c6e10c162080c482cff184b8c13309bf2dc [msl] Print functions in dependency order by James Price <jrprice@google.com>
- 35944abcf8f589a23331b6a8e4d4abcccb910bf0 [ir] Add Module::DependencyOrderedFunctions() by James Price <jrprice@google.com>
- 7a0d4c5da04921814fce4ce8cb25e276bce9a84d [msl] const-qualify many things in printer by James Price <jrprice@google.com>
- aa39095a577b9940c50be63b64047a860696854a [tint][ast] Fix DirectVariableAccess with uncalled functi... by Ben Clayton <bclayton@google.com>
- f03cc9432edbd5647f46fbfdad86335f8e5ed117 Fixup doxygen issues by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: edd2305fe411613e031da2232fe7f8c4ca9acef1
Change-Id: I1a59b8f6d5715a375609b3c63f6d22e3fdbcab83
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/188400
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/control_instruction.h b/src/tint/lang/core/ir/control_instruction.h
index 0f98763..60f29e2 100644
--- a/src/tint/lang/core/ir/control_instruction.h
+++ b/src/tint/lang/core/ir/control_instruction.h
@@ -54,6 +54,10 @@
/// @param cb the function to call once for each block
virtual void ForeachBlock(const std::function<void(ir::Block*)>& cb) = 0;
+ /// Calls @p cb for each block owned by this control instruction
+ /// @param cb the function to call once for each block
+ virtual void ForeachBlock(const std::function<void(const ir::Block*)>& cb) const = 0;
+
/// @return All the exits for the flow control instruction
const Hashset<Exit*, 2>& Exits() const { return exits_; }
diff --git a/src/tint/lang/core/ir/if.cc b/src/tint/lang/core/ir/if.cc
index be25a7c..c7a2aad 100644
--- a/src/tint/lang/core/ir/if.cc
+++ b/src/tint/lang/core/ir/if.cc
@@ -63,6 +63,15 @@
}
}
+void If::ForeachBlock(const std::function<void(const ir::Block*)>& cb) const {
+ if (true_) {
+ cb(true_);
+ }
+ if (false_) {
+ cb(false_);
+ }
+}
+
If* If::Clone(CloneContext& ctx) {
auto* cond = ctx.Remap(Condition());
auto* new_true = ctx.ir.blocks.Create<ir::Block>();
diff --git a/src/tint/lang/core/ir/if.h b/src/tint/lang/core/ir/if.h
index 7dfe6d2..3a6e0d7 100644
--- a/src/tint/lang/core/ir/if.h
+++ b/src/tint/lang/core/ir/if.h
@@ -77,6 +77,9 @@
/// @copydoc ControlInstruction::ForeachBlock
void ForeachBlock(const std::function<void(ir::Block*)>& cb) override;
+ /// @copydoc ControlInstruction::ForeachBlock
+ void ForeachBlock(const std::function<void(const ir::Block*)>& cb) const override;
+
/// @returns the if condition
Value* Condition() { return operands_[kConditionOperandOffset]; }
diff --git a/src/tint/lang/core/ir/loop.cc b/src/tint/lang/core/ir/loop.cc
index a2c8e83..fca440f 100644
--- a/src/tint/lang/core/ir/loop.cc
+++ b/src/tint/lang/core/ir/loop.cc
@@ -87,6 +87,18 @@
}
}
+void Loop::ForeachBlock(const std::function<void(const ir::Block*)>& cb) const {
+ if (initializer_) {
+ cb(initializer_);
+ }
+ if (body_) {
+ cb(body_);
+ }
+ if (continuing_) {
+ cb(continuing_);
+ }
+}
+
bool Loop::HasInitializer() const {
return initializer_->Terminator() != nullptr;
}
diff --git a/src/tint/lang/core/ir/loop.h b/src/tint/lang/core/ir/loop.h
index c87760e..edfca17 100644
--- a/src/tint/lang/core/ir/loop.h
+++ b/src/tint/lang/core/ir/loop.h
@@ -88,6 +88,9 @@
/// @copydoc ControlInstruction::ForeachBlock
void ForeachBlock(const std::function<void(ir::Block*)>& cb) override;
+ /// @copydoc ControlInstruction::ForeachBlock
+ void ForeachBlock(const std::function<void(const ir::Block*)>& cb) const override;
+
/// @returns the switch initializer block
ir::Block* Initializer() { return initializer_; }
diff --git a/src/tint/lang/core/ir/module.cc b/src/tint/lang/core/ir/module.cc
index cbc747b..a1be474 100644
--- a/src/tint/lang/core/ir/module.cc
+++ b/src/tint/lang/core/ir/module.cc
@@ -28,11 +28,76 @@
#include "src/tint/lang/core/ir/module.h"
#include <limits>
+#include <utility>
+#include "src/tint/lang/core/ir/control_instruction.h"
+#include "src/tint/lang/core/ir/user_call.h"
+#include "src/tint/utils/containers/unique_vector.h"
#include "src/tint/utils/ice/ice.h"
namespace tint::core::ir {
+namespace {
+
+/// Helper to non-recursively sort a module's function in dependency order.
+struct FunctionSorter {
+ /// The dependency-ordered list of functions.
+ Vector<const Function*, 16> ordered_functions{};
+
+ /// The functions that have been visited and checked for dependencies.
+ Hashset<const Function*, 16> visited{};
+ /// A stack of functions that need to processed and eventually added to the ordered list.
+ Vector<const Function*, 16> function_stack{};
+
+ /// Visit a function and check for dependencies, and eventually add it to the ordered list.
+ /// @param func the function to visit
+ void Visit(const Function* func) {
+ function_stack.Push(func);
+ while (!function_stack.IsEmpty()) {
+ // Visit the next function on the stack, if it hasn't already been visited.
+ auto* current_function = function_stack.Back();
+ if (visited.Add(current_function)) {
+ // Check for dependencies inside the function, adding them to the queue if they have
+ // not already been visited.
+ Visit(current_function->Block());
+ } else {
+ // We previously visited the function, so just discard it.
+ function_stack.Pop();
+ }
+
+ // If the function at the top of the stack has been visited, we know that it has no
+ // unvisited dependencies. We can now add it to the ordered list, and walk back down the
+ // stack until we find the next unvisited function.
+ while (!function_stack.IsEmpty() && visited.Contains(function_stack.Back())) {
+ ordered_functions.Push(function_stack.Pop());
+ }
+ }
+ }
+
+ /// Visit a function body block and look for dependencies.
+ /// @param block the function body to visit
+ void Visit(const Block* block) {
+ Vector<const Block*, 64> block_stack;
+ block_stack.Push(block);
+ while (!block_stack.IsEmpty()) {
+ auto* current_block = block_stack.Pop();
+ for (auto* inst : *current_block) {
+ if (auto* control = inst->As<ControlInstruction>()) {
+ // Enqueue child blocks.
+ control->ForeachBlock([&](const Block* b) { block_stack.Push(b); });
+ } else if (auto* call = inst->As<UserCall>()) {
+ // Enqueue the function that is being called.
+ if (!visited.Contains(call->Target())) {
+ function_stack.Push(call->Target());
+ }
+ }
+ }
+ }
+ }
+};
+
+} // namespace
+
Module::Module() : root_block(blocks.Create<ir::Block>()) {}
Module::Module(Module&&) = default;
@@ -71,4 +136,12 @@
value_to_name_.Remove(value);
}
+Vector<const Function*, 16> Module::DependencyOrderedFunctions() const {
+ FunctionSorter sorter;
+ for (auto& func : functions) {
+ sorter.Visit(func);
+ }
+ return std::move(sorter.ordered_functions);
+}
+
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/module.h b/src/tint/lang/core/ir/module.h
index f4a93a1..ebbd0c0 100644
--- a/src/tint/lang/core/ir/module.h
+++ b/src/tint/lang/core/ir/module.h
@@ -129,6 +129,9 @@
return {allocators.values.Objects()};
}
+ /// @returns the functions in the module, in dependency order
+ Vector<const Function*, 16> DependencyOrderedFunctions() const;
+
/// The block allocator
BlockAllocator<Block> blocks;
@@ -144,7 +147,7 @@
BlockAllocator<Value> values;
} allocators;
- /// List of functions in the program
+ /// List of functions in the module.
Vector<ConstPropagatingPtr<Function>, 8> functions;
/// The block containing module level declarations, if any exist.
diff --git a/src/tint/lang/core/ir/module_test.cc b/src/tint/lang/core/ir/module_test.cc
index 6cfa913..b04ae70 100644
--- a/src/tint/lang/core/ir/module_test.cc
+++ b/src/tint/lang/core/ir/module_test.cc
@@ -26,9 +26,13 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/core/ir/module.h"
+
+#include "gmock/gmock.h"
#include "src/tint/lang/core/ir/ir_helper_test.h"
#include "src/tint/lang/core/ir/var.h"
+using ::testing::ElementsAre;
+
namespace tint::core::ir {
namespace {
@@ -55,5 +59,42 @@
EXPECT_EQ(mod.NameOf(v).Name(), "b");
}
+TEST_F(IR_ModuleTest, DependencyOrderedFunctions) {
+ auto* fa = b.Function("a", ty.void_());
+ auto* fb = b.Function("b", ty.void_());
+ auto* fc = b.Function("c", ty.void_());
+ auto* fd = b.Function("d", ty.void_());
+ b.Append(fa->Block(), [&] { //
+ auto* ifelse = b.If(true);
+ b.Append(ifelse->True(), [&] {
+ b.Call(fc);
+ b.ExitIf(ifelse);
+ });
+ b.Append(ifelse->False(), [&] {
+ b.Call(fb);
+ b.ExitIf(ifelse);
+ });
+ b.Return(fa);
+ });
+ b.Append(fb->Block(), [&] { //
+ b.Call(fc);
+ b.Call(fd);
+ b.Return(fb);
+ });
+ b.Append(fc->Block(), [&] { //
+ b.Call(fd);
+ b.Return(fc);
+ });
+ b.Append(fd->Block(), [&] { //
+ b.Return(fd);
+ });
+ mod.functions.Clear();
+ mod.functions.Push(fa);
+ mod.functions.Push(fd);
+ mod.functions.Push(fb);
+ mod.functions.Push(fc);
+ EXPECT_THAT(mod.DependencyOrderedFunctions(), ElementsAre(fd, fc, fb, fa));
+}
+
} // namespace
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/switch.cc b/src/tint/lang/core/ir/switch.cc
index a124daa..5ba9abf 100644
--- a/src/tint/lang/core/ir/switch.cc
+++ b/src/tint/lang/core/ir/switch.cc
@@ -53,6 +53,12 @@
}
}
+void Switch::ForeachBlock(const std::function<void(const ir::Block*)>& cb) const {
+ for (auto& c : cases_) {
+ cb(c.block);
+ }
+}
+
Switch* Switch::Clone(CloneContext& ctx) {
auto* cond = ctx.Remap(Condition());
auto* new_switch = ctx.ir.allocators.instructions.Create<Switch>(cond);
diff --git a/src/tint/lang/core/ir/switch.h b/src/tint/lang/core/ir/switch.h
index d0783b8..c1b5f3a 100644
--- a/src/tint/lang/core/ir/switch.h
+++ b/src/tint/lang/core/ir/switch.h
@@ -95,6 +95,9 @@
/// @copydoc ControlInstruction::ForeachBlock
void ForeachBlock(const std::function<void(ir::Block*)>& cb) override;
+ /// @copydoc ControlInstruction::ForeachBlock
+ void ForeachBlock(const std::function<void(const ir::Block*)>& cb) const override;
+
/// @returns the switch cases
Vector<Case, 4>& Cases() { return cases_; }
diff --git a/src/tint/lang/glsl/writer/ast_printer/ast_printer.h b/src/tint/lang/glsl/writer/ast_printer/ast_printer.h
index 808ff39..a1c41ae 100644
--- a/src/tint/lang/glsl/writer/ast_printer/ast_printer.h
+++ b/src/tint/lang/glsl/writer/ast_printer/ast_printer.h
@@ -70,7 +70,7 @@
};
/// Sanitize a program in preparation for generating GLSL.
-/// @program The program to sanitize
+/// @param program The program to sanitize
/// @param options The HLSL generator options.
/// @param entry_point the entry point to generate GLSL for
/// @returns the sanitized program and any supplementary information
diff --git a/src/tint/lang/glsl/writer/ast_printer/helper_test.h b/src/tint/lang/glsl/writer/ast_printer/helper_test.h
index f7ccd9c..dc1972d 100644
--- a/src/tint/lang/glsl/writer/ast_printer/helper_test.h
+++ b/src/tint/lang/glsl/writer/ast_printer/helper_test.h
@@ -127,8 +127,11 @@
private:
std::unique_ptr<ASTPrinter> gen_;
};
+
+/// Test class
using TestHelper = TestHelperBase<testing::Test>;
+/// Test param class
template <typename T>
using TestParamHelper = TestHelperBase<testing::TestWithParam<T>>;
diff --git a/src/tint/lang/glsl/writer/printer/helper_test.h b/src/tint/lang/glsl/writer/printer/helper_test.h
index 7550ecb..2dbcd00 100644
--- a/src/tint/lang/glsl/writer/printer/helper_test.h
+++ b/src/tint/lang/glsl/writer/printer/helper_test.h
@@ -89,8 +89,10 @@
}
};
+/// Test class
using GlslPrinterTest = GlslPrinterTestHelperBase<testing::Test>;
+/// Test param class
template <typename T>
using GlslPrinterTestWithParam = GlslPrinterTestHelperBase<testing::TestWithParam<T>>;
diff --git a/src/tint/lang/msl/writer/printer/BUILD.bazel b/src/tint/lang/msl/writer/printer/BUILD.bazel
index ef25cfe..ad3ce48 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.bazel
+++ b/src/tint/lang/msl/writer/printer/BUILD.bazel
@@ -116,6 +116,7 @@
"@gtest",
] + select({
":tint_build_msl_writer": [
+ "//src/tint/lang/msl/validate",
"//src/tint/lang/msl/writer/common",
"//src/tint/lang/msl/writer/printer",
"//src/tint/lang/msl/writer/raise",
diff --git a/src/tint/lang/msl/writer/printer/BUILD.cmake b/src/tint/lang/msl/writer/printer/BUILD.cmake
index b922ea6..8ae3d60 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.cmake
+++ b/src/tint/lang/msl/writer/printer/BUILD.cmake
@@ -126,6 +126,7 @@
if(TINT_BUILD_MSL_WRITER)
tint_target_add_dependencies(tint_lang_msl_writer_printer_test test
+ tint_lang_msl_validate
tint_lang_msl_writer_common
tint_lang_msl_writer_printer
tint_lang_msl_writer_raise
diff --git a/src/tint/lang/msl/writer/printer/BUILD.gn b/src/tint/lang/msl/writer/printer/BUILD.gn
index 3db1a8f..ec57337 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.gn
+++ b/src/tint/lang/msl/writer/printer/BUILD.gn
@@ -119,6 +119,7 @@
if (tint_build_msl_writer) {
deps += [
+ "${tint_src_dir}/lang/msl/validate",
"${tint_src_dir}/lang/msl/writer/common",
"${tint_src_dir}/lang/msl/writer/printer",
"${tint_src_dir}/lang/msl/writer/raise",
diff --git a/src/tint/lang/msl/writer/printer/binary_test.cc b/src/tint/lang/msl/writer/printer/binary_test.cc
index 495b156..fe0a464 100644
--- a/src/tint/lang/msl/writer/printer/binary_test.cc
+++ b/src/tint/lang/msl/writer/printer/binary_test.cc
@@ -91,14 +91,14 @@
ASSERT_TRUE(Generate()) << err_ << output_;
EXPECT_EQ(output_, MetalHeader() + R"(
+uint tint_div_u32(uint lhs, uint rhs) {
+ return (lhs / select(rhs, 1u, (rhs == 0u)));
+}
void foo() {
uint const left = 1u;
uint const right = 2u;
uint const val = tint_div_u32(left, right);
}
-uint tint_div_u32(uint lhs, uint rhs) {
- return (lhs / select(rhs, 1u, (rhs == 0u)));
-}
)");
}
@@ -114,15 +114,15 @@
ASSERT_TRUE(Generate()) << err_ << output_;
EXPECT_EQ(output_, MetalHeader() + R"(
+uint tint_mod_u32(uint lhs, uint rhs) {
+ uint const v = select(rhs, 1u, (rhs == 0u));
+ return (lhs - ((lhs / v) * v));
+}
void foo() {
uint const left = 1u;
uint const right = 2u;
uint const val = tint_mod_u32(left, right);
}
-uint tint_mod_u32(uint lhs, uint rhs) {
- uint const v = select(rhs, 1u, (rhs == 0u));
- return (lhs - ((lhs / v) * v));
-}
)");
}
diff --git a/src/tint/lang/msl/writer/printer/discard_test.cc b/src/tint/lang/msl/writer/printer/discard_test.cc
index 64451a3..d0ccbb5 100644
--- a/src/tint/lang/msl/writer/printer/discard_test.cc
+++ b/src/tint/lang/msl/writer/printer/discard_test.cc
@@ -32,7 +32,9 @@
namespace tint::msl::writer {
namespace {
-TEST_F(MslPrinterTest, Discard) {
+// TODO(jrprice): Disabled as DemoteToHelper introduces module-scope variables, which are not
+// handled correctly yet.
+TEST_F(MslPrinterTest, DISABLED_Discard) {
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* if_ = b.If(true);
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index 2af6bb2..bef8b61 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -34,6 +34,7 @@
#include "gtest/gtest.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/msl/validate/validate.h"
#include "src/tint/lang/msl/writer/printer/printer.h"
#include "src/tint/lang/msl/writer/raise/raise.h"
@@ -92,6 +93,14 @@
}
output_ = result.Get();
+#if TINT_BUILD_IS_MAC
+ auto msl_validation = validate::ValidateUsingMetal(output_, validate::MslVersion::kMsl_2_3);
+ if (msl_validation.failed) {
+ err_ = msl_validation.output;
+ return false;
+ }
+#endif
+
return true;
}
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index c4fadc3..40ffcdf 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -122,7 +122,7 @@
EmitBlockInstructions(ir_.root_block);
// Emit functions.
- for (auto& func : ir_.functions) {
+ for (auto* func : ir_.DependencyOrderedFunctions()) {
EmitFunction(func);
}
@@ -150,9 +150,9 @@
std::unordered_set<const core::type::Struct*> emitted_structs_;
/// The current function being emitted
- core::ir::Function* current_function_ = nullptr;
+ const core::ir::Function* current_function_ = nullptr;
/// The current block being emitted
- core::ir::Block* current_block_ = nullptr;
+ const core::ir::Block* current_block_ = nullptr;
/// Unique name of the tint_array<T, N> template.
/// Non-empty only if the template has been generated.
@@ -224,7 +224,7 @@
/// Emit the function
/// @param func the function to emit
- void EmitFunction(core::ir::Function* func) {
+ void EmitFunction(const core::ir::Function* func) {
TINT_SCOPED_ASSIGNMENT(current_function_, func);
{
@@ -318,43 +318,43 @@
/// Emit a block
/// @param block the block to emit
- void EmitBlock(core::ir::Block* block) { EmitBlockInstructions(block); }
+ void EmitBlock(const core::ir::Block* block) { EmitBlockInstructions(block); }
/// Emit the instructions in a block
/// @param block the block with the instructions to emit
- void EmitBlockInstructions(core::ir::Block* block) {
+ void EmitBlockInstructions(const core::ir::Block* block) {
TINT_SCOPED_ASSIGNMENT(current_block_, block);
for (auto* inst : *block) {
Switch(
- inst, //
- [&](core::ir::BreakIf* i) { EmitBreakIf(i); }, //
- [&](core::ir::Continue*) { EmitContinue(); }, //
- [&](core::ir::Discard*) { EmitDiscard(); }, //
- [&](core::ir::ExitIf* i) { EmitExitIf(i); }, //
- [&](core::ir::ExitLoop*) { EmitExitLoop(); }, //
- [&](core::ir::ExitSwitch*) { EmitExitSwitch(); }, //
- [&](core::ir::If* i) { EmitIf(i); }, //
- [&](core::ir::Let* i) { EmitLet(i); }, //
- [&](core::ir::Loop* i) { EmitLoop(i); }, //
- [&](core::ir::NextIteration*) { /* do nothing */ }, //
- [&](core::ir::Return* i) { EmitReturn(i); }, //
- [&](core::ir::Store* i) { EmitStore(i); }, //
- [&](core::ir::Switch* i) { EmitSwitch(i); }, //
- [&](core::ir::Unreachable*) { EmitUnreachable(); }, //
- [&](core::ir::Call* i) { EmitCallStmt(i); }, //
- [&](core::ir::Var* i) { EmitVar(i); }, //
- [&](core::ir::StoreVectorElement* e) { EmitStoreVectorElement(e); },
- [&](core::ir::TerminateInvocation*) { EmitDiscard(); }, //
+ inst, //
+ [&](const core::ir::BreakIf* i) { EmitBreakIf(i); }, //
+ [&](const core::ir::Continue*) { EmitContinue(); }, //
+ [&](const core::ir::Discard*) { EmitDiscard(); }, //
+ [&](const core::ir::ExitIf* i) { EmitExitIf(i); }, //
+ [&](const core::ir::ExitLoop*) { EmitExitLoop(); }, //
+ [&](const core::ir::ExitSwitch*) { EmitExitSwitch(); }, //
+ [&](const core::ir::If* i) { EmitIf(i); }, //
+ [&](const core::ir::Let* i) { EmitLet(i); }, //
+ [&](const core::ir::Loop* i) { EmitLoop(i); }, //
+ [&](const core::ir::NextIteration*) { /* do nothing */ }, //
+ [&](const core::ir::Return* i) { EmitReturn(i); }, //
+ [&](const core::ir::Store* i) { EmitStore(i); }, //
+ [&](const core::ir::Switch* i) { EmitSwitch(i); }, //
+ [&](const core::ir::Unreachable*) { EmitUnreachable(); }, //
+ [&](const core::ir::Call* i) { EmitCallStmt(i); }, //
+ [&](const core::ir::Var* i) { EmitVar(i); }, //
+ [&](const core::ir::StoreVectorElement* e) { EmitStoreVectorElement(e); },
+ [&](const core::ir::TerminateInvocation*) { EmitDiscard(); }, //
- [&](core::ir::LoadVectorElement*) { /* inlined */ }, //
- [&](core::ir::Swizzle*) { /* inlined */ }, //
- [&](core::ir::Bitcast*) { /* inlined */ }, //
- [&](core::ir::CoreBinary*) { /* inlined */ }, //
- [&](core::ir::CoreUnary*) { /* inlined */ }, //
- [&](core::ir::Load*) { /* inlined */ }, //
- [&](core::ir::Construct*) { /* inlined */ }, //
- [&](core::ir::Access*) { /* inlined */ }, //
+ [&](const core::ir::LoadVectorElement*) { /* inlined */ }, //
+ [&](const core::ir::Swizzle*) { /* inlined */ }, //
+ [&](const core::ir::Bitcast*) { /* inlined */ }, //
+ [&](const core::ir::CoreBinary*) { /* inlined */ }, //
+ [&](const core::ir::CoreUnary*) { /* inlined */ }, //
+ [&](const core::ir::Load*) { /* inlined */ }, //
+ [&](const core::ir::Construct*) { /* inlined */ }, //
+ [&](const core::ir::Access*) { /* inlined */ }, //
TINT_ICE_ON_NO_MATCH);
}
}
@@ -479,7 +479,7 @@
/// Emit a var instruction
/// @param v the var instruction
- void EmitVar(core::ir::Var* v) {
+ void EmitVar(const core::ir::Var* v) {
auto out = Line();
auto* ptr = v->Result(0)->Type()->As<core::type::Pointer>();
@@ -517,7 +517,7 @@
/// Emit a let instruction
/// @param l the let instruction
- void EmitLet(core::ir::Let* l) {
+ void EmitLet(const core::ir::Let* l) {
auto out = Line();
EmitType(out, l->Result(0)->Type());
out << " const " << NameOf(l->Result(0)) << " = ";
@@ -527,7 +527,7 @@
void EmitExitLoop() { Line() << "break;"; }
- void EmitBreakIf(core::ir::BreakIf* b) {
+ void EmitBreakIf(const core::ir::BreakIf* b) {
auto out = Line();
out << "if ";
EmitValue(out, b->Condition());
@@ -541,7 +541,7 @@
Line() << "continue;";
}
- void EmitLoop(core::ir::Loop* l) {
+ void EmitLoop(const core::ir::Loop* l) {
// Note, we can't just emit the continuing inside a conditional at the top of the loop
// because any variable declared in the block must be visible to the continuing.
//
@@ -572,7 +572,7 @@
void EmitExitSwitch() { Line() << "break;"; }
- void EmitSwitch(core::ir::Switch* s) {
+ void EmitSwitch(const core::ir::Switch* s) {
{
auto out = Line();
out << "switch(";
@@ -646,7 +646,7 @@
/// Emit an if instruction
/// @param if_ the if instruction
- void EmitIf(core::ir::If* if_) {
+ void EmitIf(const core::ir::If* if_) {
{
auto out = Line();
out << "if (";
@@ -671,7 +671,7 @@
/// Emit an exit-if instruction
/// @param e the exit-if instruction
- void EmitExitIf(core::ir::ExitIf* e) {
+ void EmitExitIf(const core::ir::ExitIf* e) {
auto results = e->If()->Results();
auto args = e->Args();
for (size_t i = 0; i < e->Args().Length(); ++i) {
@@ -687,7 +687,7 @@
/// Emit a return instruction
/// @param r the return instruction
- void EmitReturn(core::ir::Return* r) {
+ void EmitReturn(const core::ir::Return* r) {
// If this return has no arguments and the current block is for the function which is
// being returned, skip the return.
if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
@@ -710,7 +710,7 @@
void EmitDiscard() { Line() << "discard_fragment();"; }
/// Emit a store
- void EmitStore(core::ir::Store* s) {
+ void EmitStore(const core::ir::Store* s) {
auto out = Line();
EmitValue(out, s->To());
diff --git a/src/tint/lang/msl/writer/printer/var_test.cc b/src/tint/lang/msl/writer/printer/var_test.cc
index 04f9358..6aa7227 100644
--- a/src/tint/lang/msl/writer/printer/var_test.cc
+++ b/src/tint/lang/msl/writer/printer/var_test.cc
@@ -247,7 +247,7 @@
)");
}
-// TODO(dsinclair): Requires ModuleScopeVarToEntryPointParam transform
+// TODO(jrprice): Requires ModuleScopeVarToEntryPointParam transform
TEST_F(MslPrinterTest, DISABLED_VarGlobalPrivate) {
core::ir::Var* v = nullptr;
b.Append(mod.root_block, [&] { v = b.Var("v", ty.ptr<core::AddressSpace::kPrivate, f32>()); });
@@ -274,7 +274,8 @@
)");
}
-TEST_F(MslPrinterTest, VarGlobalWorkgroup) {
+// TODO(jrprice): Requires ModuleScopeVarToEntryPointParam transform
+TEST_F(MslPrinterTest, DISABLED_VarGlobalWorkgroup) {
core::ir::Var* v = nullptr;
b.Append(mod.root_block,
[&] { v = b.Var("v", ty.ptr<core::AddressSpace::kWorkgroup, f32>()); });
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
index dac33c7..d9093ec 100644
--- a/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access.cc
@@ -834,9 +834,9 @@
if (auto incoming_shape = variant_sig.Get(param)) {
auto& symbols = *variant.ptr_param_symbols.Get(param);
if (symbols.base_ptr.IsValid()) {
- auto base_ptr_ty =
- b.ty.ptr(incoming_shape->root.address_space,
- CreateASTTypeFor(ctx, incoming_shape->root.type));
+ auto base_ptr_ty = b.ty.ptr(
+ incoming_shape->root.address_space,
+ CreateASTTypeFor(ctx, incoming_shape->root.type->UnwrapPtrOrRef()));
params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
}
if (symbols.indices.IsValid()) {
diff --git a/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc b/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc
index 04666b2..9b18a40 100644
--- a/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/direct_variable_access_test.cc
@@ -3209,6 +3209,33 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(DirectVariableAccessFunctionASTest, PointerForwarding_NoUse) {
+ auto* src = R"(
+fn a(p : ptr<function, i32>) -> i32 {
+ return *p;
+}
+
+fn b(p : ptr<function, i32>) -> i32 {
+ return a(p);
+}
+)";
+
+ auto* expect =
+ R"(
+fn a_F(p : ptr<function, i32>) -> i32 {
+ return *(p);
+}
+
+fn b(p : ptr<function, i32>) -> i32 {
+ return a_F(p);
+}
+)";
+
+ auto got = Run<DirectVariableAccess>(src, EnableFunction());
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace function_as_tests
////////////////////////////////////////////////////////////////////////////////