Import Tint changes from Dawn
Changes:
- 04e38e884b952543601bb05c85eb175f20eb0b8c [tint][resolver] Mark all short-circuited RHS expressions... by Ben Clayton <bclayton@google.com>
- dd89ce45f54704e4cdd147f295ba872b8af50a43 [glsl] Reorder GLSL transforms. by dan sinclair <dsinclair@chromium.org>
- f58479706392f005b0c5e85d0ccc688f2db8d7e1 [ir] Fix Disassembly::EmitOperandList() with count by James Price <jrprice@google.com>
- d42ff1655946e18869808cb7b4d3930da2d9b108 [ir] Track BreakIf as a loop exit by James Price <jrprice@google.com>
- a002503e66599d1c73760953a3952e64aff114df [ir] Validate loop body with params has initializer by James Price <jrprice@google.com>
- a7e91f5d02c9f5acad3104f91c4ad19f68a05c6a [spirv] Do not add exit phi for loop initializer by James Price <jrprice@google.com>
- 02cf91129de9237ec0dd58430b3fb6c2b971b5b0 [msl] Implement ShaderIO transform by James Price <jrprice@google.com>
- c509bb3757c02beecfd6d74ef7aae9ee94b1e3ed [tint][ir][spirv] Deduplicate storage textures with diffe... by Ben Clayton <bclayton@google.com>
- d82c5cbc6ad588606f4372647d45f062d13318c5 Tint: Add support for input_attachment_index in inspector. by Le Hoang Quyen <lehoangquyen@chromium.org>
- aefe7d2223a97b0b1580b4f77e2dff40b6c56406 OpenGLES: baseVertex, baseInstance workaround. by Stephen White <senorblanco@chromium.org>
GitOrigin-RevId: 04e38e884b952543601bb05c85eb175f20eb0b8c
Change-Id: I465e514bcd76bfc566a8717501f3a4f44ba28547
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/190860
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/cmd/common/helper.cc b/src/tint/cmd/common/helper.cc
index 0fd37f8..606dd69 100644
--- a/src/tint/cmd/common/helper.cc
+++ b/src/tint/cmd/common/helper.cc
@@ -498,6 +498,8 @@
return "DepthMultisampledTexture";
case tint::inspector::ResourceBinding::ResourceType::kExternalTexture:
return "ExternalTexture";
+ case tint::inspector::ResourceBinding::ResourceType::kInputAttachment:
+ return "InputAttachment";
}
return "Unknown";
diff --git a/src/tint/lang/core/constant/eval_binary_op_test.cc b/src/tint/lang/core/constant/eval_binary_op_test.cc
index df37060..2f47723 100644
--- a/src/tint/lang/core/constant/eval_binary_op_test.cc
+++ b/src/tint/lang/core/constant/eval_binary_op_test.cc
@@ -27,6 +27,7 @@
#include "src/tint/lang/core/constant/eval_test.h"
+#include "src/tint/lang/wgsl/builtin_fn.h"
#include "src/tint/utils/result/result.h"
#if TINT_BUILD_WGSL_READER
@@ -2420,9 +2421,10 @@
WrapInFunction(Decl(Var("b", Expr(false))), binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
- EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+ ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), false);
EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
- EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
}
TEST_F(ConstEvalTest, ShortCircuit_Or_RHSVarDecl) {
@@ -2434,9 +2436,40 @@
WrapInFunction(Decl(Var("b", Expr(false))), binary);
EXPECT_TRUE(r()->Resolve()) << r()->error();
- EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+ ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), true);
EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
- EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+}
+
+TEST_F(ConstEvalTest, ShortCircuit_And_RHSRuntimeBuiltin) {
+ // fn f() {
+ // var b = false;
+ // let result = false && any(b);
+ // }
+ auto* binary = LogicalAnd(false, Call(wgsl::BuiltinFn::kAny, "b"));
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), false);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+}
+
+TEST_F(ConstEvalTest, ShortCircuit_Or_RHSRuntimeBuiltin) {
+ // fn f() {
+ // var b = false;
+ // let result = true || any(b);
+ // }
+ auto* binary = LogicalOr(true, Call(wgsl::BuiltinFn::kAny, "b"));
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), true);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
}
////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/break_if.cc b/src/tint/lang/core/ir/break_if.cc
index efb2990..72ad55b 100644
--- a/src/tint/lang/core/ir/break_if.cc
+++ b/src/tint/lang/core/ir/break_if.cc
@@ -55,6 +55,7 @@
if (loop_) {
loop_->Body()->AddInboundSiblingBranch(this);
+ SetControlInstruction(loop_);
}
}
@@ -72,6 +73,7 @@
loop_->Body()->RemoveInboundSiblingBranch(this);
}
loop_ = loop;
+ SetControlInstruction(loop);
if (loop) {
loop->Body()->AddInboundSiblingBranch(this);
}
diff --git a/src/tint/lang/core/ir/break_if.h b/src/tint/lang/core/ir/break_if.h
index 38bf9d8..47533ec 100644
--- a/src/tint/lang/core/ir/break_if.h
+++ b/src/tint/lang/core/ir/break_if.h
@@ -30,7 +30,7 @@
#include <string>
-#include "src/tint/lang/core/ir/terminator.h"
+#include "src/tint/lang/core/ir/exit.h"
#include "src/tint/lang/core/ir/value.h"
#include "src/tint/utils/containers/const_propagating_ptr.h"
#include "src/tint/utils/rtti/castable.h"
@@ -42,8 +42,8 @@
namespace tint::core::ir {
-/// A break-if iteration instruction.
-class BreakIf final : public Castable<BreakIf, Terminator> {
+/// A break-if terminator instruction.
+class BreakIf final : public Castable<BreakIf, Exit> {
public:
/// The offset in Operands() for the condition
static constexpr size_t kConditionOperandOffset = 0;
diff --git a/src/tint/lang/core/ir/break_if_test.cc b/src/tint/lang/core/ir/break_if_test.cc
index d472c47..c32f00c 100644
--- a/src/tint/lang/core/ir/break_if_test.cc
+++ b/src/tint/lang/core/ir/break_if_test.cc
@@ -113,5 +113,20 @@
EXPECT_EQ(0u, args.Length());
}
+TEST_F(IR_BreakIfTest, SetLoop) {
+ auto* loop1 = b.Loop();
+ auto* loop2 = b.Loop();
+ auto* cond = b.Constant(true);
+ auto* arg1 = b.Constant(1_u);
+ auto* arg2 = b.Constant(2_u);
+
+ auto* brk = b.BreakIf(loop1, cond, arg1, arg2);
+ EXPECT_THAT(loop1->Exits(), testing::ElementsAre(brk));
+
+ brk->SetLoop(loop2);
+ EXPECT_TRUE(loop1->Exits().IsEmpty());
+ EXPECT_THAT(loop2->Exits(), testing::ElementsAre(brk));
+}
+
} // namespace
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/disassembly.cc b/src/tint/lang/core/ir/disassembly.cc
index bd48191..193b04e 100644
--- a/src/tint/lang/core/ir/disassembly.cc
+++ b/src/tint/lang/core/ir/disassembly.cc
@@ -576,7 +576,7 @@
}
void Disassembly::EmitOperandList(const Instruction* inst, size_t start_index, size_t count) {
- size_t n = std::min(count, inst->Operands().Length());
+ size_t n = std::min(start_index + count, inst->Operands().Length());
for (size_t i = start_index; i < n; i++) {
if (i != start_index) {
out_ << ", ";
diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc
index 9586785..a8fd7b0 100644
--- a/src/tint/lang/core/ir/transform/shader_io.cc
+++ b/src/tint/lang/core/ir/transform/shader_io.cc
@@ -86,7 +86,7 @@
}
auto new_params = backend->FinalizeInputs();
- auto* new_ret_val = backend->FinalizeOutputs();
+ auto* new_ret_ty = backend->FinalizeOutputs();
// Rename the old function and remove its pipeline stage and workgroup size, as we will be
// wrapping it with a new entry point.
@@ -98,7 +98,8 @@
func->ClearWorkgroupSize();
// Create the entry point wrapper function.
- auto* ep = b.Function(name, new_ret_val ? new_ret_val->Type() : ty.void_());
+ auto* ep = b.Function(name, new_ret_ty);
+ ep->SetParams(std::move(new_params));
ep->SetStage(stage);
if (wgsize) {
ep->SetWorkgroupSize((*wgsize)[0], (*wgsize)[1], (*wgsize)[2]);
@@ -114,7 +115,7 @@
}
// Return the new result.
- wrapper.Return(ep, new_ret_val);
+ wrapper.Return(ep, backend->MakeReturnValue(wrapper));
}
/// Gather the shader inputs.
diff --git a/src/tint/lang/core/ir/transform/shader_io.h b/src/tint/lang/core/ir/transform/shader_io.h
index 5d7afd3..106aa83 100644
--- a/src/tint/lang/core/ir/transform/shader_io.h
+++ b/src/tint/lang/core/ir/transform/shader_io.h
@@ -75,8 +75,8 @@
virtual Vector<FunctionParam*, 4> FinalizeInputs() = 0;
/// Finalize the shader outputs and create state needed for the new entry point function.
- /// @returns the return value for the new entry point
- virtual Value* FinalizeOutputs() = 0;
+ /// @returns the return type for the new entry point
+ virtual const type::Type* FinalizeOutputs() = 0;
/// Get the value of the input at index @p idx
/// @param builder the IR builder for new instructions
@@ -90,6 +90,11 @@
/// @param value the value to set
virtual void SetOutput(Builder& builder, uint32_t idx, Value* value) = 0;
+ /// Create the return value for the entry point, based on the output values that have been set.
+ /// @param builder the IR builder for new instructions
+ /// @returns the return value for the new entry point
+ virtual Value* MakeReturnValue([[maybe_unused]] Builder& builder) { return nullptr; }
+
/// @returns true if a vertex point size builtin should be added
virtual bool NeedsVertexPointSize() const { return false; }
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index f42eeb5..f24b8eb 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -362,6 +362,10 @@
/// @param l the loop to validate
void CheckLoop(const Loop* l);
+ /// Validates the loop body block
+ /// @param l the loop to validate
+ void CheckLoopBody(const Loop* l);
+
/// Validates the loop continuing block
/// @param l the loop to validate
void CheckLoopContinuing(const Loop* l);
@@ -1217,13 +1221,25 @@
});
}
- tasks_.Push([this, l] { BeginBlock(l->Body()); });
+ tasks_.Push([this, l] {
+ CheckLoopBody(l);
+ BeginBlock(l->Body());
+ });
if (!l->Initializer()->IsEmpty()) {
tasks_.Push([this, l] { BeginBlock(l->Initializer()); });
}
tasks_.Push([this, l] { control_stack_.Push(l); });
}
+void Validator::CheckLoopBody(const Loop* loop) {
+ // If the body block has parameters, there must be an initializer block.
+ if (!loop->Body()->Params().IsEmpty()) {
+ if (!loop->HasInitializer()) {
+ AddError(loop) << "loop with body block parameters must have an initializer";
+ }
+ }
+}
+
void Validator::CheckLoopContinuing(const Loop* loop) {
if (!loop->HasContinuing()) {
return;
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 39029fb..870d345 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -528,6 +528,7 @@
auto* p = b.BlockParam("my_param", ty.f32());
b.Append(f->Block(), [&] {
auto* l = b.Loop();
+ b.Append(l->Initializer(), [&] { b.NextIteration(l, nullptr); });
l->Body()->SetParams({p});
b.Append(l->Body(), [&] { b.ExitLoop(l); });
b.Return(f);
@@ -538,15 +539,18 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:4:12 error: destroyed parameter found in block parameter list
- $B2 (%my_param:f32): { # body
+ R"(:7:12 error: destroyed parameter found in block parameter list
+ $B3 (%my_param:f32): { # body
^^^^^^^^^
note: # Disassembly
%my_func = func():void {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%my_param:f32): { # body
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration undef # -> $B3
+ }
+ $B3 (%my_param:f32): { # body
exit_loop # loop_1
}
}
@@ -562,6 +566,7 @@
auto* p = b.BlockParam("my_param", ty.f32());
b.Append(f->Block(), [&] {
auto* l = b.Loop();
+ b.Append(l->Initializer(), [&] { b.NextIteration(l, nullptr); });
l->Body()->SetParams({p});
b.Append(l->Body(), [&] { b.ExitLoop(l); });
b.Return(f);
@@ -572,15 +577,18 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:4:12 error: block parameter has nullptr parent block
- $B2 (%my_param:f32): { # body
+ R"(:7:12 error: block parameter has nullptr parent block
+ $B3 (%my_param:f32): { # body
^^^^^^^^^
note: # Disassembly
%my_func = func():void {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%my_param:f32): { # body
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration undef # -> $B3
+ }
+ $B3 (%my_param:f32): { # body
exit_loop # loop_1
}
}
@@ -596,6 +604,7 @@
auto* p = b.BlockParam("my_param", ty.f32());
b.Append(f->Block(), [&] {
auto* l = b.Loop();
+ b.Append(l->Initializer(), [&] { b.NextIteration(l, nullptr); });
l->Body()->SetParams({p});
b.Append(l->Body(), [&] { b.Continue(l, p); });
l->Continuing()->SetParams({p});
@@ -606,23 +615,26 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:4:12 error: block parameter has incorrect parent block
- $B2 (%my_param:f32): { # body
+ R"(:7:12 error: block parameter has incorrect parent block
+ $B3 (%my_param:f32): { # body
^^^^^^^^^
-:7:7 note: parent block declared here
- $B3 (%my_param:f32): { # continuing
+:10:7 note: parent block declared here
+ $B4 (%my_param:f32): { # continuing
^^^^^^^^^^^^^^^^^^^
note: # Disassembly
%my_func = func():void {
$B1: {
- loop [b: $B2, c: $B3] { # loop_1
- $B2 (%my_param:f32): { # body
- continue %my_param # -> $B3
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ next_iteration undef # -> $B3
}
- $B3 (%my_param:f32): { # continuing
- next_iteration %my_param # -> $B2
+ $B3 (%my_param:f32): { # body
+ continue %my_param # -> $B4
+ }
+ $B4 (%my_param:f32): { # continuing
+ next_iteration %my_param # -> $B3
}
}
ret
@@ -3179,6 +3191,40 @@
ASSERT_EQ(res, Success);
}
+TEST_F(IR_ValidatorTest, LoopBodyParamsWithoutInitializer) {
+ auto* f = b.Function("my_func", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* loop = b.Loop();
+ loop->Body()->SetParams({b.BlockParam<i32>(), b.BlockParam<i32>()});
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop); });
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:3:5 error: loop: loop with body block parameters must have an initializer
+ loop [b: $B2] { # loop_1
+ ^^^^^^^^^^^^^
+
+:2:3 note: in block
+ $B1: {
+ ^^^
+
+note: # Disassembly
+%my_func = func():void {
+ $B1: {
+ loop [b: $B2] { # loop_1
+ $B2 (%2:i32, %3:i32): { # body
+ exit_loop # loop_1
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, ContinuingUseValueBeforeContinue) {
auto* f = b.Function("my_func", ty.void_());
auto* value = b.Let("value", 1_i);
@@ -3273,8 +3319,8 @@
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
R"(:8:9 error: break_if: provides 2 values but 'loop' block $B2 expects 0 values
- break_if true next_iteration: [ 1i ] # -> [t: exit_loop loop_1, f: $B2]
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ break_if true next_iteration: [ 1i, 2i ] # -> [t: exit_loop loop_1, f: $B2]
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:7:7 note: in block
$B3: { # continuing
@@ -3292,7 +3338,7 @@
continue # -> $B3
}
$B3: { # continuing
- break_if true next_iteration: [ 1i ] # -> [t: exit_loop loop_1, f: $B2]
+ break_if true next_iteration: [ 1i, 2i ] # -> [t: exit_loop loop_1, f: $B2]
}
}
ret
@@ -3306,6 +3352,7 @@
b.Append(f->Block(), [&] {
auto* loop = b.Loop();
loop->Body()->SetParams({b.BlockParam<i32>(), b.BlockParam<i32>()});
+ b.Append(loop->Initializer(), [&] { b.NextIteration(loop, nullptr, nullptr); });
b.Append(loop->Body(), [&] { b.Continue(loop); });
b.Append(loop->Continuing(), [&] { b.BreakIf(loop, true, Empty, Empty); });
b.Return(f);
@@ -3314,27 +3361,30 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:8:9 error: break_if: provides 0 values but 'loop' block $B2 expects 2 values
- break_if true # -> [t: exit_loop loop_1, f: $B2]
+ R"(:11:9 error: break_if: provides 0 values but 'loop' block $B3 expects 2 values
+ break_if true # -> [t: exit_loop loop_1, f: $B3]
^^^^^^^^^^^^^
-:7:7 note: in block
- $B3: { # continuing
+:10:7 note: in block
+ $B4: { # continuing
^^^
-:4:7 note: 'loop' block $B2 declared here
- $B2 (%2:i32, %3:i32): { # body
+:7:7 note: 'loop' block $B3 declared here
+ $B3 (%2:i32, %3:i32): { # body
^^^^^^^^^^^^^^^^^^^^
note: # Disassembly
%my_func = func():void {
$B1: {
- loop [b: $B2, c: $B3] { # loop_1
- $B2 (%2:i32, %3:i32): { # body
- continue # -> $B3
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ next_iteration undef, undef # -> $B3
}
- $B3: { # continuing
- break_if true # -> [t: exit_loop loop_1, f: $B2]
+ $B3 (%2:i32, %3:i32): { # body
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ break_if true # -> [t: exit_loop loop_1, f: $B3]
}
}
ret
@@ -3349,6 +3399,8 @@
auto* loop = b.Loop();
loop->Body()->SetParams(
{b.BlockParam<i32>(), b.BlockParam<f32>(), b.BlockParam<u32>(), b.BlockParam<bool>()});
+ b.Append(loop->Initializer(),
+ [&] { b.NextIteration(loop, nullptr, nullptr, nullptr, nullptr); });
b.Append(loop->Body(), [&] { b.Continue(loop); });
b.Append(loop->Continuing(),
[&] { b.BreakIf(loop, true, b.Values(1_i, 2_i, 3_f, false), Empty); });
@@ -3359,39 +3411,42 @@
ASSERT_NE(res, Success);
EXPECT_EQ(
res.Failure().reason.Str(),
- R"(:8:45 error: break_if: operand with type 'i32' does not match 'loop' block $B2 target type 'f32'
- break_if true next_iteration: [ 1i, 2i, 3.0f ] # -> [t: exit_loop loop_1, f: $B2]
+ R"(:11:45 error: break_if: operand with type 'i32' does not match 'loop' block $B3 target type 'f32'
+ break_if true next_iteration: [ 1i, 2i, 3.0f, false ] # -> [t: exit_loop loop_1, f: $B3]
^^
-:7:7 note: in block
- $B3: { # continuing
+:10:7 note: in block
+ $B4: { # continuing
^^^
-:4:20 note: %3 declared here
- $B2 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
+:7:20 note: %3 declared here
+ $B3 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
^^
-:8:49 error: break_if: operand with type 'f32' does not match 'loop' block $B2 target type 'u32'
- break_if true next_iteration: [ 1i, 2i, 3.0f ] # -> [t: exit_loop loop_1, f: $B2]
+:11:49 error: break_if: operand with type 'f32' does not match 'loop' block $B3 target type 'u32'
+ break_if true next_iteration: [ 1i, 2i, 3.0f, false ] # -> [t: exit_loop loop_1, f: $B3]
^^^^
-:7:7 note: in block
- $B3: { # continuing
+:10:7 note: in block
+ $B4: { # continuing
^^^
-:4:28 note: %4 declared here
- $B2 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
+:7:28 note: %4 declared here
+ $B3 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
^^
note: # Disassembly
%my_func = func():void {
$B1: {
- loop [b: $B2, c: $B3] { # loop_1
- $B2 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
- continue # -> $B3
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ next_iteration undef, undef, undef, undef # -> $B3
}
- $B3: { # continuing
- break_if true next_iteration: [ 1i, 2i, 3.0f ] # -> [t: exit_loop loop_1, f: $B2]
+ $B3 (%2:i32, %3:f32, %4:u32, %5:bool): { # body
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ break_if true next_iteration: [ 1i, 2i, 3.0f, false ] # -> [t: exit_loop loop_1, f: $B3]
}
}
ret
@@ -3406,6 +3461,8 @@
auto* loop = b.Loop();
loop->Body()->SetParams(
{b.BlockParam<i32>(), b.BlockParam<f32>(), b.BlockParam<u32>(), b.BlockParam<bool>()});
+ b.Append(loop->Initializer(),
+ [&] { b.NextIteration(loop, nullptr, nullptr, nullptr, nullptr); });
b.Append(loop->Body(), [&] { b.Continue(loop); });
b.Append(loop->Continuing(),
[&] { b.BreakIf(loop, true, b.Values(1_i, 2_f, 3_u, false), Empty); });
diff --git a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
index 1cbcee4..29abc93 100644
--- a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
@@ -170,11 +170,51 @@
manager.Add<ast::transform::Robustness>();
}
+ if (!options.disable_workgroup_init) {
+ // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
+ // ZeroInitWorkgroupMemory may inject new builtin parameters.
+ manager.Add<ast::transform::ZeroInitWorkgroupMemory>();
+ }
+
+ manager.Add<ast::transform::RemovePhonies>();
+
+ // TextureBuiltinsFromUniform must come before CombineSamplers to preserve texture binding point
+ // info, instead of combined sampler binding point. As a result, TextureBuiltinsFromUniform also
+ // comes before BindingRemapper so the binding point info it reflects is before remapping.
+ manager.Add<TextureBuiltinsFromUniform>();
+ data.Add<TextureBuiltinsFromUniform::Config>(
+ options.texture_builtins_from_uniform.ubo_binding,
+ options.texture_builtins_from_uniform.ubo_bindingpoint_ordering);
+
// Note: it is more efficient for MultiplanarExternalTexture to come after Robustness
+ // Must come before builtin polyfills
data.Add<ast::transform::MultiplanarExternalTexture::NewBindingPoints>(
options.external_texture_options.bindings_map);
manager.Add<ast::transform::MultiplanarExternalTexture>();
+ // Must be after multiplanar and must be before OffsetFirstindex
+ manager.Add<ast::transform::AddBlockAttribute>();
+
+ // This must come before ClampFragDepth as the AddBlockAttribute will change around the struct
+ // that gets created for the push constants and we end up with the `inner` structure sitting at
+ // the same offset we want to place the first_instance value.
+ manager.Add<ast::transform::OffsetFirstIndex>();
+ data.Add<ast::transform::OffsetFirstIndex::Config>(options.first_vertex_offset,
+ options.first_instance_offset);
+
+ // ClampFragDepth must come before CanonicalizeEntryPointIO, or the assignments to FragDepth are
+ // lost
+ manager.Add<ast::transform::ClampFragDepth>();
+ data.Add<ast::transform::ClampFragDepth::Config>(options.depth_range_offsets);
+
+ // CanonicalizeEntryPointIO must come after Robustness
+ manager.Add<ast::transform::CanonicalizeEntryPointIO>();
+ data.Add<ast::transform::CanonicalizeEntryPointIO::Config>(
+ ast::transform::CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
+
+ // DemoteToHelper must come after PromoteSideEffectsToDecl and ExpandCompoundAssignment.
+ manager.Add<ast::transform::DemoteToHelper>();
+
{ // Builtin polyfills
ast::transform::BuiltinPolyfill::Builtins polyfills;
polyfills.acosh = ast::transform::BuiltinPolyfill::Level::kRangeCheck;
@@ -201,48 +241,21 @@
manager.Add<ast::transform::DirectVariableAccess>();
- if (!options.disable_workgroup_init) {
- // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
- // ZeroInitWorkgroupMemory may inject new builtin parameters.
- manager.Add<ast::transform::ZeroInitWorkgroupMemory>();
- }
-
- manager.Add<ast::transform::AddBlockAttribute>();
-
- manager.Add<ast::transform::OffsetFirstIndex>();
-
- // ClampFragDepth must come before CanonicalizeEntryPointIO, or the assignments to FragDepth are
- // lost
- manager.Add<ast::transform::ClampFragDepth>();
-
- // CanonicalizeEntryPointIO must come after Robustness
- manager.Add<ast::transform::CanonicalizeEntryPointIO>();
-
- // PadStructs must come after CanonicalizeEntryPointIO
- manager.Add<PadStructs>();
-
- // DemoteToHelper must come after PromoteSideEffectsToDecl and ExpandCompoundAssignment.
- manager.Add<ast::transform::DemoteToHelper>();
-
- manager.Add<ast::transform::RemovePhonies>();
-
- // TextureBuiltinsFromUniform must come before CombineSamplers to preserve texture binding point
- // info, instead of combined sampler binding point. As a result, TextureBuiltinsFromUniform also
- // comes before BindingRemapper so the binding point info it reflects is before remapping.
- manager.Add<TextureBuiltinsFromUniform>();
- data.Add<TextureBuiltinsFromUniform::Config>(
- options.texture_builtins_from_uniform.ubo_binding,
- options.texture_builtins_from_uniform.ubo_bindingpoint_ordering);
-
+ // Must come after builtin polyfills (specifically texture_sample_base_clamp_to_edge_2d_f32)
data.Add<CombineSamplersInfo>(options.combined_samplers_info);
manager.Add<CombineSamplers>();
+ // Must come after CombineSamplers
data.Add<ast::transform::BindingRemapper::Remappings>(
options.binding_remapper_options.binding_points,
std::unordered_map<BindingPoint, core::Access>{},
/* allow_collisions */ true);
manager.Add<ast::transform::BindingRemapper>();
+ // PadStructs must come after CanonicalizeEntryPointIO and CombineSamplers
+ manager.Add<PadStructs>();
+
+ // Promote initializers must come after binding polyfill
manager.Add<ast::transform::PromoteInitializersToLet>();
manager.Add<ast::transform::RemoveContinueInSwitch>();
manager.Add<ast::transform::AddEmptyEntryPoint>();
@@ -254,13 +267,6 @@
manager.Add<ast::transform::SimplifyPointers>();
- data.Add<ast::transform::CanonicalizeEntryPointIO::Config>(
- ast::transform::CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
-
- data.Add<ast::transform::OffsetFirstIndex::Config>(std::nullopt, options.first_instance_offset);
-
- data.Add<ast::transform::ClampFragDepth::Config>(options.depth_range_offsets);
-
SanitizedResult result;
ast::transform::DataMap outputs;
result.program = manager.Run(in, data, outputs);
diff --git a/src/tint/lang/glsl/writer/ast_printer/member_accessor_test.cc b/src/tint/lang/glsl/writer/ast_printer/member_accessor_test.cc
index 302d899..1e4ee0c 100644
--- a/src/tint/lang/glsl/writer/ast_printer/member_accessor_test.cc
+++ b/src/tint/lang/glsl/writer/ast_printer/member_accessor_test.cc
@@ -312,13 +312,13 @@
Data inner;
} data;
-void assign_and_preserve_padding_data_b(mat2x3 value) {
+void assign_and_preserve_padding_data_inner_b(mat2x3 value) {
data.inner.b[0] = value[0u];
data.inner.b[1] = value[1u];
}
void tint_symbol() {
- assign_and_preserve_padding_data_b(mat2x3(vec3(0.0f), vec3(0.0f)));
+ assign_and_preserve_padding_data_inner_b(mat2x3(vec3(0.0f), vec3(0.0f)));
}
void main() {
diff --git a/src/tint/lang/glsl/writer/common/options.h b/src/tint/lang/glsl/writer/common/options.h
index f6f114a..3ba7540 100644
--- a/src/tint/lang/glsl/writer/common/options.h
+++ b/src/tint/lang/glsl/writer/common/options.h
@@ -156,6 +156,9 @@
/// Options used in the binding mappings for external textures
ExternalTextureOptions external_texture_options = {};
+ /// Offset of the firstVertex push constant.
+ std::optional<int32_t> first_vertex_offset;
+
/// Offset of the firstInstance push constant.
std::optional<int32_t> first_instance_offset;
@@ -176,6 +179,7 @@
combined_samplers_info,
binding_remapper_options,
external_texture_options,
+ first_vertex_offset,
first_instance_offset,
depth_range_offsets,
texture_builtins_from_uniform);
diff --git a/src/tint/lang/msl/writer/function_test.cc b/src/tint/lang/msl/writer/function_test.cc
index 21afd3a..197fe70 100644
--- a/src/tint/lang/msl/writer/function_test.cc
+++ b/src/tint/lang/msl/writer/function_test.cc
@@ -28,6 +28,8 @@
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/msl/writer/helper_test.h"
+using namespace tint::core::fluent_types; // NOLINT
+
namespace tint::msl::writer {
namespace {
@@ -43,34 +45,58 @@
}
TEST_F(MslWriterTest, EntryPointParameterBufferBindingPoint) {
- auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
- auto* storage = b.FunctionParam("storage", ty.ptr(core::AddressSpace::kStorage, ty.i32()));
- auto* uniform = b.FunctionParam("uniform", ty.ptr(core::AddressSpace::kUniform, ty.i32()));
+ auto* storage = b.Var("storage_var", ty.ptr(core::AddressSpace::kStorage, ty.i32()));
+ auto* uniform = b.Var("uniform_var", ty.ptr(core::AddressSpace::kUniform, ty.i32()));
storage->SetBindingPoint(0, 1);
uniform->SetBindingPoint(0, 2);
- func->SetParams({storage, uniform});
- func->Block()->Append(b.Return(func));
+ mod.root_block->Append(storage);
+ mod.root_block->Append(uniform);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Load(storage);
+ b.Load(uniform);
+ b.Return(func);
+ });
ASSERT_TRUE(Generate()) << err_ << output_.msl;
- EXPECT_EQ(output_.msl, MetalHeader() + R"(
-fragment void foo(device int* storage [[buffer(1)]], const constant int* uniform [[buffer(2)]]) {
+ EXPECT_EQ(output_.msl, MetalHeader() + R"(struct tint_module_vars_struct {
+ device int* storage_var;
+ const constant int* uniform_var;
+};
+
+fragment void foo(device int* storage_var [[buffer(1)]], const constant int* uniform_var [[buffer(2)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.storage_var=storage_var, .uniform_var=uniform_var};
}
)");
}
TEST_F(MslWriterTest, EntryPointParameterHandleBindingPoint) {
auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
- auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
- auto* texture = b.FunctionParam("texture", t);
- auto* sampler = b.FunctionParam("sampler", ty.sampler());
+ auto* texture = b.Var("texture", ty.ptr<handle>(t));
+ auto* sampler = b.Var("sampler", ty.ptr<handle>(ty.sampler()));
texture->SetBindingPoint(0, 1);
sampler->SetBindingPoint(0, 2);
- func->SetParams({texture, sampler});
- func->Block()->Append(b.Return(func));
+ mod.root_block->Append(texture);
+ mod.root_block->Append(sampler);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Load(texture);
+ b.Load(sampler);
+ b.Return(func);
+ });
ASSERT_TRUE(Generate()) << err_ << output_.msl;
- EXPECT_EQ(output_.msl, MetalHeader() + R"(
+ EXPECT_EQ(output_.msl, R"(#include <metal_stdlib>
+using namespace metal;
+struct tint_module_vars_struct {
+ texture2d<float, access::sample> texture;
+ sampler sampler;
+};
+
fragment void foo(texture2d<float, access::sample> texture [[texture(1)]], sampler sampler [[sampler(2)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.texture=texture, .sampler=sampler};
}
)");
}
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index b2de07b..71b31a0 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -289,7 +289,6 @@
}
++i;
- // TODO(dsinclair): Handle parameter attributes
EmitType(out, param->Type());
out << " ";
@@ -301,41 +300,15 @@
out << NameOf(param);
- if (param->Builtin().has_value()) {
- out << " [[";
- switch (param->Builtin().value()) {
- case core::BuiltinValue::kFrontFacing:
- out << "front_facing";
- break;
- case core::BuiltinValue::kGlobalInvocationId:
- out << "thread_position_in_grid";
- break;
- case core::BuiltinValue::kLocalInvocationId:
- out << "thread_position_in_threadgroup";
- break;
- case core::BuiltinValue::kLocalInvocationIndex:
- out << "thread_index_in_threadgroup";
- break;
- case core::BuiltinValue::kNumWorkgroups:
- out << "threadgroups_per_grid";
- break;
- case core::BuiltinValue::kPosition:
- out << "position";
- break;
- case core::BuiltinValue::kSampleIndex:
- out << "sample_id";
- break;
- case core::BuiltinValue::kSampleMask:
- out << "sample_mask";
- break;
- case core::BuiltinValue::kWorkgroupId:
- out << "threadgroup_position_in_grid";
- break;
+ if (auto builtin = param->Builtin()) {
+ auto name = BuiltinToAttribute(builtin.value());
+ TINT_ASSERT(!name.empty());
+ out << " [[" << name << "]]";
+ }
- default:
- break;
- }
- out << "]]";
+ if (param->Type()->Is<core::type::Struct>() &&
+ func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
+ out << " [[stage_in]]";
}
auto ptr = param->Type()->As<core::type::Pointer>();
diff --git a/src/tint/lang/msl/writer/raise/BUILD.bazel b/src/tint/lang/msl/writer/raise/BUILD.bazel
index 1c1d899..0373b09 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.bazel
+++ b/src/tint/lang/msl/writer/raise/BUILD.bazel
@@ -42,11 +42,13 @@
"builtin_polyfill.cc",
"module_scope_vars.cc",
"raise.cc",
+ "shader_io.cc",
],
hdrs = [
"builtin_polyfill.h",
"module_scope_vars.h",
"raise.h",
+ "shader_io.h",
],
deps = [
"//src/tint/api/common",
@@ -89,6 +91,7 @@
srcs = [
"builtin_polyfill_test.cc",
"module_scope_vars_test.cc",
+ "shader_io_test.cc",
],
deps = [
"//src/tint/api/common",
diff --git a/src/tint/lang/msl/writer/raise/BUILD.cmake b/src/tint/lang/msl/writer/raise/BUILD.cmake
index 396344d..1c0a709 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.cmake
+++ b/src/tint/lang/msl/writer/raise/BUILD.cmake
@@ -47,6 +47,8 @@
lang/msl/writer/raise/module_scope_vars.h
lang/msl/writer/raise/raise.cc
lang/msl/writer/raise/raise.h
+ lang/msl/writer/raise/shader_io.cc
+ lang/msl/writer/raise/shader_io.h
)
tint_target_add_dependencies(tint_lang_msl_writer_raise lib
@@ -93,6 +95,7 @@
tint_add_target(tint_lang_msl_writer_raise_test test
lang/msl/writer/raise/builtin_polyfill_test.cc
lang/msl/writer/raise/module_scope_vars_test.cc
+ lang/msl/writer/raise/shader_io_test.cc
)
tint_target_add_dependencies(tint_lang_msl_writer_raise_test test
diff --git a/src/tint/lang/msl/writer/raise/BUILD.gn b/src/tint/lang/msl/writer/raise/BUILD.gn
index afc543c..b7f3e7b 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.gn
+++ b/src/tint/lang/msl/writer/raise/BUILD.gn
@@ -50,6 +50,8 @@
"module_scope_vars.h",
"raise.cc",
"raise.h",
+ "shader_io.cc",
+ "shader_io.h",
]
deps = [
"${tint_src_dir}/api/common",
@@ -90,6 +92,7 @@
sources = [
"builtin_polyfill_test.cc",
"module_scope_vars_test.cc",
+ "shader_io_test.cc",
]
deps = [
"${tint_src_dir}:gmock_and_gtest",
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 09caa3c..ad8cbfe 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -44,6 +44,7 @@
#include "src/tint/lang/msl/writer/common/option_helpers.h"
#include "src/tint/lang/msl/writer/raise/builtin_polyfill.h"
#include "src/tint/lang/msl/writer/raise/module_scope_vars.h"
+#include "src/tint/lang/msl/writer/raise/shader_io.h"
namespace tint::msl::writer {
@@ -108,6 +109,7 @@
// DemoteToHelper must come before any transform that introduces non-core instructions.
RUN_TRANSFORM(core::ir::transform::DemoteToHelper);
+ RUN_TRANSFORM(raise::ShaderIO, raise::ShaderIOConfig{options.emit_vertex_point_size});
RUN_TRANSFORM(raise::ModuleScopeVars);
RUN_TRANSFORM(core::ir::transform::ValueToLet);
RUN_TRANSFORM(raise::BuiltinPolyfill);
diff --git a/src/tint/lang/msl/writer/raise/shader_io.cc b/src/tint/lang/msl/writer/raise/shader_io.cc
new file mode 100644
index 0000000..c4c8839
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/shader_io.cc
@@ -0,0 +1,197 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/msl/writer/raise/shader_io.h"
+
+#include <memory>
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/transform/shader_io.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::msl::writer::raise {
+
+namespace {
+
+/// State that persists across the whole module and can be shared between entry points.
+struct PerModuleState {
+ /// The frag_depth clamp arguments.
+ core::ir::Value* frag_depth_clamp_args = nullptr;
+};
+
+/// PIMPL state for the parts of the shader IO transform specific to MSL.
+/// For MSL, we take builtin inputs as entry point parameters, move non-builtin inputs to a struct
+/// passed as an entry point parameter, and wrap outputs in a structure returned by the entry point.
+struct StateImpl : core::ir::transform::ShaderIOBackendState {
+ /// The configuration options.
+ const ShaderIOConfig& config;
+
+ /// The per-module state object.
+ PerModuleState& module_state;
+
+ /// The input parameters of the entry point.
+ Vector<core::ir::FunctionParam*, 4> input_params;
+
+ /// The list of input indices which map to parameter and optional struct member accesses.
+ struct InputIndex {
+ const uint32_t param_index;
+ const uint32_t member_index;
+ };
+ Vector<InputIndex, 4> input_indices;
+
+ /// The output struct type.
+ core::type::Struct* output_struct = nullptr;
+
+ /// The output values to return from the entry point.
+ Vector<core::ir::Value*, 4> output_values;
+
+ /// Constructor
+ StateImpl(core::ir::Module& mod,
+ core::ir::Function* f,
+ const ShaderIOConfig& cfg,
+ PerModuleState& mod_state)
+ : ShaderIOBackendState(mod, f), config(cfg), module_state(mod_state) {}
+
+ /// Destructor
+ ~StateImpl() override {}
+
+ /// @copydoc ShaderIO::BackendState::FinalizeInputs
+ Vector<core::ir::FunctionParam*, 4> FinalizeInputs() override {
+ Vector<core::type::Manager::StructMemberDesc, 4> input_struct_members;
+ core::ir::FunctionParam* input_struct_param = nullptr;
+ uint32_t input_struct_param_index = 0xffffffff;
+
+ for (auto input : inputs) {
+ if (input.attributes.builtin) {
+ auto* param = b.FunctionParam(input.name.Name(), input.type);
+ param->SetInvariant(input.attributes.invariant);
+ param->SetBuiltin(input.attributes.builtin.value());
+ input_indices.Push(InputIndex{static_cast<uint32_t>(input_params.Length()), 0u});
+ input_params.Push(param);
+ } else if (input.attributes.location) {
+ if (!input_struct_param) {
+ input_struct_param = b.FunctionParam("inputs", nullptr);
+ input_struct_param_index = static_cast<uint32_t>(input_params.Length());
+ input_params.Push(input_struct_param);
+ }
+ input_indices.Push(
+ InputIndex{input_struct_param_index,
+ static_cast<uint32_t>(input_struct_members.Length())});
+ input_struct_members.Push(input);
+ }
+ }
+
+ if (!input_struct_members.IsEmpty()) {
+ auto* input_struct =
+ ty.Struct(ir.symbols.New(ir.NameOf(func).Name() + "_inputs"), input_struct_members);
+ switch (func->Stage()) {
+ case core::ir::Function::PipelineStage::kFragment:
+ input_struct->AddUsage(core::type::PipelineStageUsage::kFragmentInput);
+ break;
+ case core::ir::Function::PipelineStage::kVertex:
+ input_struct->AddUsage(core::type::PipelineStageUsage::kVertexInput);
+ break;
+ case core::ir::Function::PipelineStage::kCompute:
+ case core::ir::Function::PipelineStage::kUndefined:
+ TINT_UNREACHABLE();
+ }
+ input_struct_param->SetType(input_struct);
+ }
+
+ return input_params;
+ }
+
+ /// @copydoc ShaderIO::BackendState::FinalizeOutputs
+ const core::type::Type* FinalizeOutputs() override {
+ if (outputs.IsEmpty()) {
+ return ty.void_();
+ }
+ output_struct = ty.Struct(ir.symbols.New(ir.NameOf(func).Name() + "_outputs"), outputs);
+ switch (func->Stage()) {
+ case core::ir::Function::PipelineStage::kFragment:
+ output_struct->AddUsage(core::type::PipelineStageUsage::kFragmentOutput);
+ break;
+ case core::ir::Function::PipelineStage::kVertex:
+ output_struct->AddUsage(core::type::PipelineStageUsage::kVertexOutput);
+ break;
+ case core::ir::Function::PipelineStage::kCompute:
+ case core::ir::Function::PipelineStage::kUndefined:
+ TINT_UNREACHABLE();
+ }
+ output_values.Resize(outputs.Length());
+ return output_struct;
+ }
+
+ /// @copydoc ShaderIO::BackendState::GetInput
+ core::ir::Value* GetInput(core::ir::Builder& builder, uint32_t idx) override {
+ auto index = input_indices[idx];
+ auto* param = input_params[index.param_index];
+ if (auto* str = param->Type()->As<core::type::Struct>()) {
+ return builder.Access(inputs[idx].type, param, u32(index.member_index))->Result(0);
+ } else {
+ return param;
+ }
+ }
+
+ /// @copydoc ShaderIO::BackendState::SetOutput
+ void SetOutput(core::ir::Builder&, uint32_t idx, core::ir::Value* value) override {
+ output_values[idx] = value;
+ }
+
+ /// @copydoc ShaderIO::BackendState::MakeReturnValue
+ core::ir::Value* MakeReturnValue(core::ir::Builder& builder) override {
+ if (!output_struct) {
+ return nullptr;
+ }
+ return builder.Construct(output_struct, std::move(output_values))->Result(0);
+ }
+
+ /// @copydoc ShaderIO::BackendState::NeedsVertexPointSize
+ bool NeedsVertexPointSize() const override { return config.emit_vertex_point_size; }
+};
+} // namespace
+
+Result<SuccessType> ShaderIO(core::ir::Module& ir, const ShaderIOConfig& config) {
+ auto result = ValidateAndDumpIfNeeded(ir, "ShaderIO transform");
+ if (result != Success) {
+ return result;
+ }
+
+ PerModuleState module_state;
+ core::ir::transform::RunShaderIOBase(ir, [&](core::ir::Module& mod, core::ir::Function* func) {
+ return std::make_unique<StateImpl>(mod, func, config, module_state);
+ });
+
+ return Success;
+}
+
+} // namespace tint::msl::writer::raise
diff --git a/src/tint/lang/msl/writer/raise/shader_io.h b/src/tint/lang/msl/writer/raise/shader_io.h
new file mode 100644
index 0000000..4bd7ab0
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/shader_io.h
@@ -0,0 +1,54 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_MSL_WRITER_RAISE_SHADER_IO_H_
+#define SRC_TINT_LANG_MSL_WRITER_RAISE_SHADER_IO_H_
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::msl::writer::raise {
+
+/// ShaderIOConfig describes the set of configuration options for the ShaderIO transform.
+struct ShaderIOConfig {
+ /// true if a vertex point size builtin output should be added
+ bool emit_vertex_point_size = false;
+};
+
+/// ShaderIO is a transform that prepares entry point inputs and outputs for MSL codegen.
+/// @param module the module to transform
+/// @param config the configuration
+/// @returns success or failure
+Result<SuccessType> ShaderIO(core::ir::Module& module, const ShaderIOConfig& config);
+
+} // namespace tint::msl::writer::raise
+
+#endif // SRC_TINT_LANG_MSL_WRITER_RAISE_SHADER_IO_H_
diff --git a/src/tint/lang/msl/writer/raise/shader_io_test.cc b/src/tint/lang/msl/writer/raise/shader_io_test.cc
new file mode 100644
index 0000000..478a774
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/shader_io_test.cc
@@ -0,0 +1,976 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/msl/writer/raise/shader_io.h"
+
+namespace tint::msl::writer::raise {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using MslWriter_ShaderIOTest = core::ir::transform::TransformTest;
+
+TEST_F(MslWriter_ShaderIOTest, NoInputsOrOutputs) {
+ auto* ep = b.Function("foo", ty.void_());
+ ep->SetStage(core::ir::Function::PipelineStage::kCompute);
+ ep->SetWorkgroupSize(1, 1, 1);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute @workgroup_size(1, 1, 1) func():void {
+ $B1: {
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, Parameters_NonStruct) {
+ auto* ep = b.Function("foo", ty.void_());
+ auto* front_facing = b.FunctionParam("front_facing", ty.bool_());
+ front_facing->SetBuiltin(core::BuiltinValue::kFrontFacing);
+ auto* position = b.FunctionParam("position", ty.vec4<f32>());
+ position->SetBuiltin(core::BuiltinValue::kPosition);
+ position->SetInvariant(true);
+ auto* color1 = b.FunctionParam("color1", ty.f32());
+ color1->SetLocation(0, {});
+ auto* color2 = b.FunctionParam("color2", ty.f32());
+ color2->SetLocation(1, core::Interpolation{core::InterpolationType::kLinear,
+ core::InterpolationSampling::kSample});
+
+ ep->SetParams({front_facing, position, color1, color2});
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+ b.Append(ep->Block(), [&] {
+ auto* ifelse = b.If(front_facing);
+ b.Append(ifelse->True(), [&] {
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @fragment func(%front_facing:bool [@front_facing], %position:vec4<f32> [@invariant, @position], %color1:f32 [@location(0)], %color2:f32 [@location(1), @interpolate(linear, sample)]):void {
+ $B1: {
+ if %front_facing [t: $B2] { # if_1
+ $B2: { # true
+ %6:f32 = add %color1, %color2
+ %7:vec4<f32> = mul %position, %6
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_inputs = struct @align(4) {
+ color1:f32 @offset(0), @location(0)
+ color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%foo_inner = func(%front_facing:bool, %position:vec4<f32>, %color1:f32, %color2:f32):void {
+ $B1: {
+ if %front_facing [t: $B2] { # if_1
+ $B2: { # true
+ %6:f32 = add %color1, %color2
+ %7:vec4<f32> = mul %position, %6
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func(%front_facing_1:bool [@front_facing], %position_1:vec4<f32> [@invariant, @position], %inputs:foo_inputs):void { # %front_facing_1: 'front_facing', %position_1: 'position'
+ $B3: {
+ %12:f32 = access %inputs, 0u
+ %13:f32 = access %inputs, 1u
+ %14:void = call %foo_inner, %front_facing_1, %position_1, %12, %13
+ ret
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, Parameters_Struct) {
+ auto* str_ty = ty.Struct(mod.symbols.New("Inputs"),
+ {
+ {
+ mod.symbols.New("front_facing"),
+ ty.bool_(),
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kFrontFacing,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kPosition,
+ /* interpolation */ std::nullopt,
+ /* invariant */ true,
+ },
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("color2"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 1u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */
+ core::Interpolation{
+ core::InterpolationType::kLinear,
+ core::InterpolationSampling::kSample,
+ },
+ /* invariant */ false,
+ },
+ },
+ });
+
+ auto* ep = b.Function("foo", ty.void_());
+ auto* str_param = b.FunctionParam("inputs", str_ty);
+ ep->SetParams({str_param});
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+ b.Append(ep->Block(), [&] {
+ auto* ifelse = b.If(b.Access(ty.bool_(), str_param, 0_i));
+ b.Append(ifelse->True(), [&] {
+ auto* position = b.Access(ty.vec4<f32>(), str_param, 1_i);
+ auto* color1 = b.Access(ty.f32(), str_param, 2_i);
+ auto* color2 = b.Access(ty.f32(), str_param, 3_i);
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+Inputs = struct @align(16) {
+ front_facing:bool @offset(0), @builtin(front_facing)
+ position:vec4<f32> @offset(16), @invariant, @builtin(position)
+ color1:f32 @offset(32), @location(0)
+ color2:f32 @offset(36), @location(1), @interpolate(linear, sample)
+}
+
+%foo = @fragment func(%inputs:Inputs):void {
+ $B1: {
+ %3:bool = access %inputs, 0i
+ if %3 [t: $B2] { # if_1
+ $B2: { # true
+ %4:vec4<f32> = access %inputs, 1i
+ %5:f32 = access %inputs, 2i
+ %6:f32 = access %inputs, 3i
+ %7:f32 = add %5, %6
+ %8:vec4<f32> = mul %4, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inputs = struct @align(16) {
+ front_facing:bool @offset(0)
+ position:vec4<f32> @offset(16)
+ color1:f32 @offset(32)
+ color2:f32 @offset(36)
+}
+
+foo_inputs = struct @align(4) {
+ Inputs_color1:f32 @offset(0), @location(0)
+ Inputs_color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%foo_inner = func(%inputs:Inputs):void {
+ $B1: {
+ %3:bool = access %inputs, 0i
+ if %3 [t: $B2] { # if_1
+ $B2: { # true
+ %4:vec4<f32> = access %inputs, 1i
+ %5:f32 = access %inputs, 2i
+ %6:f32 = access %inputs, 3i
+ %7:f32 = add %5, %6
+ %8:vec4<f32> = mul %4, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func(%Inputs_front_facing:bool [@front_facing], %Inputs_position:vec4<f32> [@invariant, @position], %inputs_1:foo_inputs):void { # %inputs_1: 'inputs'
+ $B3: {
+ %13:f32 = access %inputs_1, 0u
+ %14:f32 = access %inputs_1, 1u
+ %15:Inputs = construct %Inputs_front_facing, %Inputs_position, %13, %14
+ %16:void = call %foo_inner, %15
+ ret
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, Parameters_Mixed) {
+ auto* str_ty = ty.Struct(mod.symbols.New("Inputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kPosition,
+ /* interpolation */ std::nullopt,
+ /* invariant */ true,
+ },
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ });
+
+ auto* ep = b.Function("foo", ty.void_());
+ auto* front_facing = b.FunctionParam("front_facing", ty.bool_());
+ front_facing->SetBuiltin(core::BuiltinValue::kFrontFacing);
+ auto* str_param = b.FunctionParam("inputs", str_ty);
+ auto* color2 = b.FunctionParam("color2", ty.f32());
+ color2->SetLocation(1, core::Interpolation{core::InterpolationType::kLinear,
+ core::InterpolationSampling::kSample});
+
+ ep->SetParams({front_facing, str_param, color2});
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+ b.Append(ep->Block(), [&] {
+ auto* ifelse = b.If(front_facing);
+ b.Append(ifelse->True(), [&] {
+ auto* position = b.Access(ty.vec4<f32>(), str_param, 0_i);
+ auto* color1 = b.Access(ty.f32(), str_param, 1_i);
+ b.Multiply(ty.vec4<f32>(), position, b.Add(ty.f32(), color1, color2));
+ b.ExitIf(ifelse);
+ });
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+Inputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @invariant, @builtin(position)
+ color1:f32 @offset(16), @location(0)
+}
+
+%foo = @fragment func(%front_facing:bool [@front_facing], %inputs:Inputs, %color2:f32 [@location(1), @interpolate(linear, sample)]):void {
+ $B1: {
+ if %front_facing [t: $B2] { # if_1
+ $B2: { # true
+ %5:vec4<f32> = access %inputs, 0i
+ %6:f32 = access %inputs, 1i
+ %7:f32 = add %6, %color2
+ %8:vec4<f32> = mul %5, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color1:f32 @offset(16)
+}
+
+foo_inputs = struct @align(4) {
+ Inputs_color1:f32 @offset(0), @location(0)
+ color2:f32 @offset(4), @location(1), @interpolate(linear, sample)
+}
+
+%foo_inner = func(%front_facing:bool, %inputs:Inputs, %color2:f32):void {
+ $B1: {
+ if %front_facing [t: $B2] { # if_1
+ $B2: { # true
+ %5:vec4<f32> = access %inputs, 0i
+ %6:f32 = access %inputs, 1i
+ %7:f32 = add %6, %color2
+ %8:vec4<f32> = mul %5, %7
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%foo = @fragment func(%front_facing_1:bool [@front_facing], %Inputs_position:vec4<f32> [@invariant, @position], %inputs_1:foo_inputs):void { # %front_facing_1: 'front_facing', %inputs_1: 'inputs'
+ $B3: {
+ %13:f32 = access %inputs_1, 0u
+ %14:Inputs = construct %Inputs_position, %13
+ %15:f32 = access %inputs_1, 1u
+ %16:void = call %foo_inner, %front_facing_1, %14, %15
+ ret
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, ReturnValue_NonStructBuiltin) {
+ auto* ep = b.Function("foo", ty.vec4<f32>());
+ ep->SetReturnBuiltin(core::BuiltinValue::kPosition);
+ ep->SetReturnInvariant(true);
+ ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+ });
+
+ auto* src = R"(
+%foo = @vertex func():vec4<f32> [@invariant, @position] {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_outputs = struct @align(16) {
+ tint_symbol:vec4<f32> @offset(0), @invariant, @builtin(position)
+}
+
+%foo_inner = func():vec4<f32> {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+%foo = @vertex func():foo_outputs {
+ $B2: {
+ %4:vec4<f32> = call %foo_inner
+ %5:foo_outputs = construct %4
+ ret %5
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, ReturnValue_NonStructLocation) {
+ auto* ep = b.Function("foo", ty.vec4<f32>());
+ ep->SetReturnLocation(1u, {});
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+ });
+
+ auto* src = R"(
+%foo = @fragment func():vec4<f32> [@location(1)] {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_outputs = struct @align(16) {
+ tint_symbol:vec4<f32> @offset(0), @location(1)
+}
+
+%foo_inner = func():vec4<f32> {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+%foo = @fragment func():foo_outputs {
+ $B2: {
+ %4:vec4<f32> = call %foo_inner
+ %5:foo_outputs = construct %4
+ ret %5
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, ReturnValue_Struct) {
+ auto* str_ty = ty.Struct(mod.symbols.New("Outputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ ty.vec4<f32>(),
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kPosition,
+ /* interpolation */ std::nullopt,
+ /* invariant */ true,
+ },
+ },
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("color2"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 1u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */
+ core::Interpolation{
+ core::InterpolationType::kLinear,
+ core::InterpolationSampling::kSample,
+ },
+ /* invariant */ false,
+ },
+ },
+ });
+
+ auto* ep = b.Function("foo", str_ty);
+ ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(str_ty, b.Construct(ty.vec4<f32>(), 0_f), 0.25_f, 0.75_f));
+ });
+
+ auto* src = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @invariant, @builtin(position)
+ color1:f32 @offset(16), @location(0)
+ color2:f32 @offset(20), @location(1), @interpolate(linear, sample)
+}
+
+%foo = @vertex func():Outputs {
+ $B1: {
+ %2:vec4<f32> = construct 0.0f
+ %3:Outputs = construct %2, 0.25f, 0.75f
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color1:f32 @offset(16)
+ color2:f32 @offset(20)
+}
+
+foo_outputs = struct @align(16) {
+ Outputs_position:vec4<f32> @offset(0), @invariant, @builtin(position)
+ Outputs_color1:f32 @offset(16), @location(0)
+ Outputs_color2:f32 @offset(20), @location(1), @interpolate(linear, sample)
+}
+
+%foo_inner = func():Outputs {
+ $B1: {
+ %2:vec4<f32> = construct 0.0f
+ %3:Outputs = construct %2, 0.25f, 0.75f
+ ret %3
+ }
+}
+%foo = @vertex func():foo_outputs {
+ $B2: {
+ %5:Outputs = call %foo_inner
+ %6:vec4<f32> = access %5, 0u
+ %7:f32 = access %5, 1u
+ %8:f32 = access %5, 2u
+ %9:foo_outputs = construct %6, %7, %8
+ ret %9
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, ReturnValue_DualSourceBlending) {
+ auto* str_ty =
+ ty.Struct(mod.symbols.New("Output"), {
+ {
+ mod.symbols.New("color1"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ 0u,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("color2"),
+ ty.f32(),
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ 1u,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ });
+
+ auto* ep = b.Function("foo", str_ty);
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(str_ty, 0.25_f, 0.75_f));
+ });
+
+ auto* src = R"(
+Output = struct @align(4) {
+ color1:f32 @offset(0), @location(0)
+ color2:f32 @offset(4), @location(0)
+}
+
+%foo = @fragment func():Output {
+ $B1: {
+ %2:Output = construct 0.25f, 0.75f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Output = struct @align(4) {
+ color1:f32 @offset(0)
+ color2:f32 @offset(4)
+}
+
+foo_outputs = struct @align(4) {
+ Output_color1:f32 @offset(0), @location(0)
+ Output_color2:f32 @offset(4), @location(0)
+}
+
+%foo_inner = func():Output {
+ $B1: {
+ %2:Output = construct 0.25f, 0.75f
+ ret %2
+ }
+}
+%foo = @fragment func():foo_outputs {
+ $B2: {
+ %4:Output = call %foo_inner
+ %5:f32 = access %4, 0u
+ %6:f32 = access %4, 1u
+ %7:foo_outputs = construct %5, %6
+ ret %7
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, Struct_SharedByVertexAndFragment) {
+ auto* vec4f = ty.vec4<f32>();
+ auto* str_ty = ty.Struct(mod.symbols.New("Interface"),
+ {
+ {
+ mod.symbols.New("position"),
+ vec4f,
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kPosition,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("color"),
+ vec4f,
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ });
+
+ // Vertex shader.
+ {
+ auto* ep = b.Function("vert", str_ty);
+ ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+
+ b.Append(ep->Block(), [&] { //
+ auto* position = b.Construct(vec4f, 0_f);
+ auto* color = b.Construct(vec4f, 1_f);
+ b.Return(ep, b.Construct(str_ty, position, color));
+ });
+ }
+
+ // Fragment shader.
+ {
+ auto* ep = b.Function("frag", vec4f);
+ auto* inputs = b.FunctionParam("inputs", str_ty);
+ ep->SetStage(core::ir::Function::PipelineStage::kFragment);
+ ep->SetParams({inputs});
+ ep->SetReturnLocation(0u, {});
+
+ b.Append(ep->Block(), [&] { //
+ auto* position = b.Access(vec4f, inputs, 0_u);
+ auto* color = b.Access(vec4f, inputs, 1_u);
+ b.Return(ep, b.Add(vec4f, position, color));
+ });
+ }
+
+ auto* src = R"(
+Interface = struct @align(16) {
+ position:vec4<f32> @offset(0), @builtin(position)
+ color:vec4<f32> @offset(16), @location(0)
+}
+
+%vert = @vertex func():Interface {
+ $B1: {
+ %2:vec4<f32> = construct 0.0f
+ %3:vec4<f32> = construct 1.0f
+ %4:Interface = construct %2, %3
+ ret %4
+ }
+}
+%frag = @fragment func(%inputs:Interface):vec4<f32> [@location(0)] {
+ $B2: {
+ %7:vec4<f32> = access %inputs, 0u
+ %8:vec4<f32> = access %inputs, 1u
+ %9:vec4<f32> = add %7, %8
+ ret %9
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Interface = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color:vec4<f32> @offset(16)
+}
+
+vert_outputs = struct @align(16) {
+ Interface_position:vec4<f32> @offset(0), @builtin(position)
+ Interface_color:vec4<f32> @offset(16), @location(0)
+}
+
+frag_inputs = struct @align(16) {
+ Interface_color:vec4<f32> @offset(0), @location(0)
+}
+
+frag_outputs = struct @align(16) {
+ tint_symbol:vec4<f32> @offset(0), @location(0)
+}
+
+%vert_inner = func():Interface {
+ $B1: {
+ %2:vec4<f32> = construct 0.0f
+ %3:vec4<f32> = construct 1.0f
+ %4:Interface = construct %2, %3
+ ret %4
+ }
+}
+%frag_inner = func(%inputs:Interface):vec4<f32> {
+ $B2: {
+ %7:vec4<f32> = access %inputs, 0u
+ %8:vec4<f32> = access %inputs, 1u
+ %9:vec4<f32> = add %7, %8
+ ret %9
+ }
+}
+%vert = @vertex func():vert_outputs {
+ $B3: {
+ %11:Interface = call %vert_inner
+ %12:vec4<f32> = access %11, 0u
+ %13:vec4<f32> = access %11, 1u
+ %14:vert_outputs = construct %12, %13
+ ret %14
+ }
+}
+%frag = @fragment func(%Interface_position:vec4<f32> [@position], %inputs_1:frag_inputs):frag_outputs { # %inputs_1: 'inputs'
+ $B4: {
+ %18:vec4<f32> = access %inputs_1, 0u
+ %19:Interface = construct %Interface_position, %18
+ %20:vec4<f32> = call %frag_inner, %19
+ %21:frag_outputs = construct %20
+ ret %21
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, Struct_SharedWithBuffer) {
+ auto* vec4f = ty.vec4<f32>();
+ auto* str_ty = ty.Struct(mod.symbols.New("Outputs"),
+ {
+ {
+ mod.symbols.New("position"),
+ vec4f,
+ core::type::StructMemberAttributes{
+ /* location */ std::nullopt,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ core::BuiltinValue::kPosition,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ {
+ mod.symbols.New("color"),
+ vec4f,
+ core::type::StructMemberAttributes{
+ /* location */ 0u,
+ /* index */ std::nullopt,
+ /* color */ std::nullopt,
+ /* builtin */ std::nullopt,
+ /* interpolation */ std::nullopt,
+ /* invariant */ false,
+ },
+ },
+ });
+
+ auto* buffer = mod.root_block->Append(b.Var(ty.ptr(storage, str_ty, read)));
+
+ auto* ep = b.Function("vert", str_ty);
+ ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Load(buffer));
+ });
+
+ auto* src = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0), @builtin(position)
+ color:vec4<f32> @offset(16), @location(0)
+}
+
+$B1: { # root
+ %1:ptr<storage, Outputs, read> = var
+}
+
+%vert = @vertex func():Outputs {
+ $B2: {
+ %3:Outputs = load %1
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Outputs = struct @align(16) {
+ position:vec4<f32> @offset(0)
+ color:vec4<f32> @offset(16)
+}
+
+vert_outputs = struct @align(16) {
+ Outputs_position:vec4<f32> @offset(0), @builtin(position)
+ Outputs_color:vec4<f32> @offset(16), @location(0)
+}
+
+$B1: { # root
+ %1:ptr<storage, Outputs, read> = var
+}
+
+%vert_inner = func():Outputs {
+ $B2: {
+ %3:Outputs = load %1
+ ret %3
+ }
+}
+%vert = @vertex func():vert_outputs {
+ $B3: {
+ %5:Outputs = call %vert_inner
+ %6:vec4<f32> = access %5, 0u
+ %7:vec4<f32> = access %5, 1u
+ %8:vert_outputs = construct %6, %7
+ ret %8
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_ShaderIOTest, EmitVertexPointSize) {
+ auto* ep = b.Function("foo", ty.vec4<f32>());
+ ep->SetStage(core::ir::Function::PipelineStage::kVertex);
+ ep->SetReturnBuiltin(core::BuiltinValue::kPosition);
+
+ b.Append(ep->Block(), [&] { //
+ b.Return(ep, b.Construct(ty.vec4<f32>(), 0.5_f));
+ });
+
+ auto* src = R"(
+%foo = @vertex func():vec4<f32> [@position] {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+foo_outputs = struct @align(16) {
+ tint_symbol:vec4<f32> @offset(0), @builtin(position)
+ vertex_point_size:f32 @offset(16), @builtin(__point_size)
+}
+
+%foo_inner = func():vec4<f32> {
+ $B1: {
+ %2:vec4<f32> = construct 0.5f
+ ret %2
+ }
+}
+%foo = @vertex func():foo_outputs {
+ $B2: {
+ %4:vec4<f32> = call %foo_inner
+ %5:foo_outputs = construct %4, 1.0f
+ ret %5
+ }
+}
+)";
+
+ ShaderIOConfig config;
+ config.emit_vertex_point_size = true;
+ Run(ShaderIO, config);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::msl::writer::raise
diff --git a/src/tint/lang/msl/writer/writer_test.cc b/src/tint/lang/msl/writer/writer_test.cc
index 52896f7..f3ec09b 100644
--- a/src/tint/lang/msl/writer/writer_test.cc
+++ b/src/tint/lang/msl/writer/writer_test.cc
@@ -58,17 +58,16 @@
ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, R"(#include <metal_stdlib>
using namespace metal;
-struct tint_symbol_2 {
- int tint_symbol;
- int tint_symbol_1;
-};
struct tint_module_vars_struct {
threadgroup int* a;
threadgroup int* b;
};
+struct tint_symbol_2 {
+ int tint_symbol;
+ int tint_symbol_1;
+};
-kernel void foo(uint tint_local_index [[thread_index_in_threadgroup]], threadgroup tint_symbol_2* v [[threadgroup(0)]]) {
- tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.a=(&(*v).tint_symbol), .b=(&(*v).tint_symbol_1)};
+void foo_inner(uint tint_local_index, tint_module_vars_struct tint_module_vars) {
if ((tint_local_index == 0u)) {
(*tint_module_vars.a) = 0;
(*tint_module_vars.b) = 0;
@@ -78,6 +77,10 @@
}
kernel void bar() {
}
+kernel void foo(uint tint_local_index [[thread_index_in_threadgroup]], threadgroup tint_symbol_2* v [[threadgroup(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.a=(&(*v).tint_symbol), .b=(&(*v).tint_symbol_1)};
+ foo_inner(tint_local_index, tint_module_vars);
+}
)");
ASSERT_EQ(output_.workgroup_allocations.size(), 2u);
ASSERT_EQ(output_.workgroup_allocations.count("foo"), 1u);
diff --git a/src/tint/lang/spirv/writer/common/helper_test.h b/src/tint/lang/spirv/writer/common/helper_test.h
index 553b173..bedc457 100644
--- a/src/tint/lang/spirv/writer/common/helper_test.h
+++ b/src/tint/lang/spirv/writer/common/helper_test.h
@@ -36,6 +36,7 @@
#include "gtest/gtest.h"
#include "spirv-tools/libspirv.hpp"
#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/disassembly.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/depth_texture.h"
@@ -241,6 +242,10 @@
}
return nullptr;
}
+
+ /// Helper to dump the disassembly of the Tint IR module.
+ /// @returns the disassembly (with a leading newline)
+ std::string IR() { return "\n" + core::ir::Disassemble(mod).Plain(); }
};
using SpirvWriterTest = SpirvWriterTestHelperBase<testing::Test>;
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index f8ccb38..1631549 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -621,5 +621,112 @@
)");
}
+TEST_F(SpirvWriterTest, Loop_ExitValue) {
+ auto* func = b.Function("foo", ty.i32());
+ b.Append(func->Block(), [&] {
+ auto* result = b.InstructionResult(ty.i32());
+ auto* loop = b.Loop();
+ loop->SetResults(Vector{result});
+ b.Append(loop->Body(), [&] { //
+ b.ExitLoop(loop, 42_i);
+ });
+ b.Return(func, result);
+ });
+
+ EXPECT_EQ(IR(), R"(
+%foo = func():i32 {
+ $B1: {
+ %2:i32 = loop [b: $B2] { # loop_1
+ $B2: { # body
+ exit_loop 42i # loop_1
+ }
+ }
+ ret %2
+ }
+}
+)");
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %8
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ %9 = OpPhi %int %int_42 %5
+ OpReturnValue %9
+ OpFunctionEnd
+)");
+}
+
+TEST_F(SpirvWriterTest, Loop_ExitValue_BreakIf) {
+ auto* func = b.Function("foo", ty.i32());
+ b.Append(func->Block(), [&] {
+ auto* result = b.InstructionResult(ty.i32());
+ auto* loop = b.Loop();
+ loop->SetResults(Vector{result});
+ b.Append(loop->Body(), [&] { //
+ auto* if_ = b.If(false);
+ b.Append(if_->True(), [&] { //
+ b.ExitLoop(loop, 1_i);
+ });
+ b.Continue(loop);
+
+ b.Append(loop->Continuing(), [&] { //
+ b.BreakIf(loop, true, Empty, 42_i);
+ });
+ });
+ b.Return(func, result);
+ });
+
+ EXPECT_EQ(IR(), R"(
+%foo = func():i32 {
+ $B1: {
+ %2:i32 = loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ if false [t: $B4] { # if_1
+ $B4: { # true
+ exit_loop 1i # loop_1
+ }
+ }
+ continue # -> $B3
+ }
+ $B3: { # continuing
+ break_if true exit_loop: [ 42i ] # -> [t: exit_loop loop_1, f: $B2]
+ }
+ }
+ ret %2
+ }
+}
+)");
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpSelectionMerge %9 None
+ OpBranchConditional %false %10 %9
+ %10 = OpLabel
+ OpBranch %8
+ %9 = OpLabel
+ OpBranch %6
+ %6 = OpLabel
+ OpBranchConditional %true %8 %7
+ %8 = OpLabel
+ %14 = OpPhi %int %int_42 %6 %int_1 %10
+ OpReturnValue %14
+ OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index c2e7104..8a7ccc1 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -32,6 +32,7 @@
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
+#include "src/tint/lang/core/access.h"
#include "src/tint/lang/core/address_space.h"
#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/core/constant/scalar.h"
@@ -157,6 +158,10 @@
[&](const core::type::DepthMultisampledTexture* depth) {
return types.Get<core::type::MultisampledTexture>(depth->dim(), types.f32());
},
+ [&](const core::type::StorageTexture* st) -> const core::type::Type* {
+ return types.Get<core::type::StorageTexture>(st->dim(), st->texel_format(),
+ core::Access::kRead, st->type());
+ },
// Both sampler types are the same in SPIR-V.
[&](const core::type::Sampler* s) -> const core::type::Type* {
@@ -2186,11 +2191,13 @@
branches.Sort(); // Sort the branches by label to ensure deterministic output
// Also add phi nodes from implicit exit blocks.
- inst->ForeachBlock([&](core::ir::Block* block) {
- if (block->IsEmpty()) {
- branches.Push(Branch{Label(block), nullptr});
- }
- });
+ if (inst->Is<core::ir::If>()) {
+ inst->ForeachBlock([&](core::ir::Block* block) {
+ if (block->IsEmpty()) {
+ branches.Push(Branch{Label(block), nullptr});
+ }
+ });
+ }
OperandList ops{Type(ty), Value(result)};
for (auto& branch : branches) {
diff --git a/src/tint/lang/spirv/writer/raise/shader_io.cc b/src/tint/lang/spirv/writer/raise/shader_io.cc
index ddb311b..9e6bc0c 100644
--- a/src/tint/lang/spirv/writer/raise/shader_io.cc
+++ b/src/tint/lang/spirv/writer/raise/shader_io.cc
@@ -145,9 +145,9 @@
}
/// @copydoc ShaderIO::BackendState::FinalizeOutputs
- core::ir::Value* FinalizeOutputs() override {
+ const core::type::Type* FinalizeOutputs() override {
MakeVars(output_vars, outputs, core::AddressSpace::kOut, core::Access::kWrite, "_Output");
- return nullptr;
+ return ty.void_();
}
/// @copydoc ShaderIO::BackendState::GetInput
diff --git a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index_test.cc b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index_test.cc
index e1a9da3..e774858 100644
--- a/src/tint/lang/spirv/writer/raise/var_for_dynamic_index_test.cc
+++ b/src/tint/lang/spirv/writer/raise/var_for_dynamic_index_test.cc
@@ -566,6 +566,9 @@
func->SetParams({cond, idx_a, idx_b});
b.Append(func->Block(), [&] { //
auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] { //
+ b.NextIteration(loop, b.Splat(arr->Type(), 0_i));
+ });
loop->Body()->SetParams({arr});
b.Append(loop->Body(), [&] {
auto* if_ = b.If(cond);
@@ -583,14 +586,17 @@
auto* src = R"(
%func = func(%2:bool, %3:i32, %4:i32):i32 {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%5:array<i32, 4>): { # body
- if %2 [t: $B3, f: $B4] { # if_1
- $B3: { # true
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration array<i32, 4>(0i) # -> $B3
+ }
+ $B3 (%5:array<i32, 4>): { # body
+ if %2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
%6:i32 = access %5, %3
ret %6
}
- $B4: { # false
+ $B5: { # false
%7:i32 = access %5, %4
ret %7
}
@@ -607,16 +613,19 @@
auto* expect = R"(
%func = func(%2:bool, %3:i32, %4:i32):i32 {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%5:array<i32, 4>): { # body
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration array<i32, 4>(0i) # -> $B3
+ }
+ $B3 (%5:array<i32, 4>): { # body
%6:ptr<function, array<i32, 4>, read_write> = var, %5
- if %2 [t: $B3, f: $B4] { # if_1
- $B3: { # true
+ if %2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
%7:ptr<function, i32, read_write> = access %6, %3
%8:i32 = load %7
ret %8
}
- $B4: { # false
+ $B5: { # false
%9:ptr<function, i32, read_write> = access %6, %4
%10:i32 = load %9
ret %10
@@ -637,7 +646,8 @@
TEST_F(SpirvWriter_VarForDynamicIndexTest,
MultipleAccessesToBlockParam_FromDifferentBlocks_WithLeadingConstantIndex) {
- auto* arr = b.BlockParam(ty.array(ty.array<i32, 4>(), 4));
+ auto* inner_ty = ty.array<i32, 4>();
+ auto* arr = b.BlockParam(ty.array(inner_ty, 4));
auto* cond = b.FunctionParam(ty.bool_());
auto* idx_a = b.FunctionParam(ty.i32());
auto* idx_b = b.FunctionParam(ty.i32());
@@ -645,6 +655,9 @@
func->SetParams({cond, idx_a, idx_b});
b.Append(func->Block(), [&] { //
auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] { //
+ b.NextIteration(loop, b.Splat(arr->Type(), b.Splat(inner_ty, 0_i)));
+ });
loop->Body()->SetParams({arr});
b.Append(loop->Body(), [&] {
auto* if_ = b.If(cond);
@@ -662,14 +675,17 @@
auto* src = R"(
%func = func(%2:bool, %3:i32, %4:i32):i32 {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%5:array<array<i32, 4>, 4>): { # body
- if %2 [t: $B3, f: $B4] { # if_1
- $B3: { # true
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration array<array<i32, 4>, 4>(array<i32, 4>(0i)) # -> $B3
+ }
+ $B3 (%5:array<array<i32, 4>, 4>): { # body
+ if %2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
%6:i32 = access %5, 0u, %3
ret %6
}
- $B4: { # false
+ $B5: { # false
%7:i32 = access %5, 0u, %4
ret %7
}
@@ -686,17 +702,20 @@
auto* expect = R"(
%func = func(%2:bool, %3:i32, %4:i32):i32 {
$B1: {
- loop [b: $B2] { # loop_1
- $B2 (%5:array<array<i32, 4>, 4>): { # body
+ loop [i: $B2, b: $B3] { # loop_1
+ $B2: { # initializer
+ next_iteration array<array<i32, 4>, 4>(array<i32, 4>(0i)) # -> $B3
+ }
+ $B3 (%5:array<array<i32, 4>, 4>): { # body
%6:array<i32, 4> = access %5, 0u
%7:ptr<function, array<i32, 4>, read_write> = var, %6
- if %2 [t: $B3, f: $B4] { # if_1
- $B3: { # true
+ if %2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
%8:ptr<function, i32, read_write> = access %7, %3
%9:i32 = load %8
ret %9
}
- $B4: { # false
+ $B5: { # false
%10:ptr<function, i32, read_write> = access %7, %4
%11:i32 = load %10
ret %11
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index 27fdc1b..febac19 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -302,6 +302,26 @@
EXPECT_INST("%v2 = OpVariable %_ptr_UniformConstant_3_0 UniformConstant");
}
+TEST_F(SpirvWriterTest, Type_StorageTexture_Dedup) {
+ b.Append(b.ir.root_block, [&] {
+ auto* v1 = b.Var("v1", ty.ptr<handle, read_write>(ty.Get<core::type::StorageTexture>(
+ core::type::TextureDimension::k2dArray,
+ core::TexelFormat::kR32Uint, core::Access::kRead, ty.u32())));
+ auto* v2 = b.Var("v2", ty.ptr<handle, read_write>(ty.Get<core::type::StorageTexture>(
+ core::type::TextureDimension::k2dArray,
+ core::TexelFormat::kR32Uint, core::Access::kWrite, ty.u32())));
+ v1->SetBindingPoint(0, 1);
+ v2->SetBindingPoint(0, 2);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%3 = OpTypeImage %uint 2D 0 1 0 2 R32ui");
+ EXPECT_INST("%_ptr_UniformConstant_3 = OpTypePointer UniformConstant %3");
+ EXPECT_INST("%v1 = OpVariable %_ptr_UniformConstant_3 UniformConstant");
+ EXPECT_INST("%_ptr_UniformConstant_3_0 = OpTypePointer UniformConstant %3");
+ EXPECT_INST("%v2 = OpVariable %_ptr_UniformConstant_3_0 UniformConstant");
+}
+
using Dim = core::type::TextureDimension;
struct TextureCase {
std::string result;
diff --git a/src/tint/lang/wgsl/helpers/flatten_bindings.cc b/src/tint/lang/wgsl/helpers/flatten_bindings.cc
index b0565cb..ae136a9 100644
--- a/src/tint/lang/wgsl/helpers/flatten_bindings.cc
+++ b/src/tint/lang/wgsl/helpers/flatten_bindings.cc
@@ -73,6 +73,9 @@
case tint::inspector::ResourceBinding::ResourceType::kExternalTexture:
binding_points.emplace(src, BindingPoint{0, next_texture_idx++});
break;
+ case tint::inspector::ResourceBinding::ResourceType::kInputAttachment:
+ // flattening is not supported for input attachments.
+ TINT_UNREACHABLE();
}
}
}
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index 9857271..8902806 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -42,6 +42,7 @@
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/input_attachment.h"
#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/sampled_texture.h"
@@ -55,6 +56,7 @@
#include "src/tint/lang/wgsl/ast/float_literal_expression.h"
#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "src/tint/lang/wgsl/ast/identifier.h"
+#include "src/tint/lang/wgsl/ast/input_attachment_index_attribute.h"
#include "src/tint/lang/wgsl/ast/int_literal_expression.h"
#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
#include "src/tint/lang/wgsl/ast/location_attribute.h"
@@ -339,6 +341,7 @@
&Inspector::GetDepthTextureResourceBindings,
&Inspector::GetDepthMultisampledTextureResourceBindings,
&Inspector::GetExternalTextureResourceBindings,
+ &Inspector::GetInputAttachmentResourceBindings,
}) {
AppendResourceBindings(&result, (this->*fn)(entry_point));
}
@@ -504,6 +507,45 @@
ResourceBinding::ResourceType::kExternalTexture);
}
+std::vector<ResourceBinding> Inspector::GetInputAttachmentResourceBindings(
+ const std::string& entry_point) {
+ auto* func = FindEntryPointByName(entry_point);
+ if (!func) {
+ return {};
+ }
+
+ std::vector<ResourceBinding> result;
+ auto* func_sem = program_.Sem().Get(func);
+ for (auto& ref : func_sem->TransitivelyReferencedVariablesOfType(
+ &tint::TypeInfo::Of<core::type::InputAttachment>())) {
+ auto* var = ref.first;
+ auto binding_info = ref.second;
+
+ ResourceBinding entry;
+ entry.resource_type = ResourceBinding::ResourceType::kInputAttachment;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
+
+ auto* sem_var = var->As<sem::GlobalVariable>();
+ TINT_ASSERT(sem_var);
+ TINT_ASSERT(sem_var->Attributes().input_attachment_index);
+ entry.input_attachmnt_index = sem_var->Attributes().input_attachment_index.value();
+
+ auto* input_attachment_type = var->Type()->UnwrapRef()->As<core::type::InputAttachment>();
+ auto* base_type = input_attachment_type->type();
+ entry.sampled_kind = BaseTypeToSampledKind(base_type);
+
+ entry.variable_name = var->Declaration()->name->symbol.Name();
+
+ entry.dim =
+ TypeTextureDimensionToResourceBindingTextureDimension(input_attachment_type->dim());
+
+ result.push_back(entry);
+ }
+
+ return result;
+}
+
VectorRef<SamplerTexturePair> Inspector::GetSamplerTextureUses(const std::string& entry_point) {
auto* func = FindEntryPointByName(entry_point);
if (!func) {
diff --git a/src/tint/lang/wgsl/inspector/inspector.h b/src/tint/lang/wgsl/inspector/inspector.h
index 736b2da..e91de8d 100644
--- a/src/tint/lang/wgsl/inspector/inspector.h
+++ b/src/tint/lang/wgsl/inspector/inspector.h
@@ -131,6 +131,13 @@
/// @returns vector of all of the bindings for external textures.
std::vector<ResourceBinding> GetExternalTextureResourceBindings(const std::string& entry_point);
+ /// Gathers all the resource bindings of the input attachment type for the given
+ /// entry point.
+ /// @param entry_point name of the entry point to get information about.
+ /// texture type.
+ /// @returns vector of all of the bindings for input attachments.
+ std::vector<ResourceBinding> GetInputAttachmentResourceBindings(const std::string& entry_point);
+
/// @param entry_point name of the entry point to get information about.
/// @returns vector of all of the sampler/texture sampling pairs that are used
/// by that entry point.
diff --git a/src/tint/lang/wgsl/inspector/inspector_test.cc b/src/tint/lang/wgsl/inspector/inspector_test.cc
index 3dd5661..ca4c4cb 100644
--- a/src/tint/lang/wgsl/inspector/inspector_test.cc
+++ b/src/tint/lang/wgsl/inspector/inspector_test.cc
@@ -2075,6 +2075,66 @@
EXPECT_EQ(3u, result[8].binding);
}
+TEST_F(InspectorGetResourceBindingsTest, InputAttachment) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(1) @input_attachment_index(3)
+ // var input_tex1 : input_attachment<f32>;
+ //
+ // @group(4) @binding(3) @input_attachment_index(1)
+ // var input_tex2 : input_attachment<i32>;
+ //
+ // fn f1() -> vec4f {
+ // return inputAttachmentLoad(input_tex1);
+ // }
+ //
+ // fn f2() -> vec4i {
+ // return inputAttachmentLoad(input_tex2);
+ // }
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ GlobalVar("input_tex1", ty.input_attachment(ty.Of<f32>()),
+ Vector{Group(0_u), Binding(1_u), InputAttachmentIndex(3_u)});
+ GlobalVar("input_tex2", ty.input_attachment(ty.Of<i32>()),
+ Vector{Group(4_u), Binding(3_u), InputAttachmentIndex(1_u)});
+
+ Func("f1", Empty, ty.vec4<f32>(),
+ Vector{
+ Return(Call("inputAttachmentLoad", "input_tex1")),
+ });
+ Func("f2", Empty, ty.vec4<i32>(),
+ Vector{
+ Return(Call("inputAttachmentLoad", "input_tex2")),
+ });
+
+ MakeCallerBodyFunction("main",
+ Vector{
+ std::string("f1"),
+ std::string("f2"),
+ },
+ Vector{
+ Stage(ast::PipelineStage::kFragment),
+ });
+
+ Inspector& inspector = Build();
+
+ auto result = inspector.GetResourceBindings("main");
+ ASSERT_FALSE(inspector.has_error()) << inspector.error();
+ ASSERT_EQ(2u, result.size());
+
+ EXPECT_EQ(ResourceBinding::ResourceType::kInputAttachment, result[0].resource_type);
+ EXPECT_EQ(0u, result[0].bind_group);
+ EXPECT_EQ(1u, result[0].binding);
+ EXPECT_EQ(3u, result[0].input_attachmnt_index);
+ EXPECT_EQ(inspector::ResourceBinding::SampledKind::kFloat, result[0].sampled_kind);
+
+ EXPECT_EQ(ResourceBinding::ResourceType::kInputAttachment, result[1].resource_type);
+ EXPECT_EQ(4u, result[1].bind_group);
+ EXPECT_EQ(3u, result[1].binding);
+ EXPECT_EQ(1u, result[1].input_attachmnt_index);
+ EXPECT_EQ(inspector::ResourceBinding::SampledKind::kSInt, result[1].sampled_kind);
+}
+
TEST_F(InspectorGetUniformBufferResourceBindingsTest, MissingEntryPoint) {
Inspector& inspector = Build();
diff --git a/src/tint/lang/wgsl/inspector/resource_binding.h b/src/tint/lang/wgsl/inspector/resource_binding.h
index 1c1f801..a62819b 100644
--- a/src/tint/lang/wgsl/inspector/resource_binding.h
+++ b/src/tint/lang/wgsl/inspector/resource_binding.h
@@ -98,7 +98,8 @@
kReadWriteStorageTexture,
kDepthTexture,
kDepthMultisampledTexture,
- kExternalTexture
+ kExternalTexture,
+ kInputAttachment,
};
/// Type of resource that is bound.
@@ -107,6 +108,8 @@
uint32_t bind_group;
/// Identifier to identify this binding within the bind group
uint32_t binding;
+ /// Input attachment index. Only available for input attachments.
+ uint32_t input_attachmnt_index;
/// Size for this binding, in bytes, if defined.
uint64_t size;
/// Size for this binding without trailing structure padding, in bytes, if
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 56b70e5..2f14dd4 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -1594,7 +1594,7 @@
// Mark entire expression tree to not const-evaluate
auto r = ast::TraverseExpressions( //
(*binary)->rhs, [&](const ast::Expression* e) {
- skip_const_eval_.Add(e);
+ not_evaluated_.Add(e);
return ast::TraverseAction::Descend;
});
if (!r) {
@@ -1878,7 +1878,7 @@
return expr;
}
- auto* load = b.create<sem::Load>(expr, current_statement_);
+ auto* load = b.create<sem::Load>(expr, current_statement_, expr->Stage());
load->Behaviors() = expr->Behaviors();
b.Sem().Replace(expr->Declaration(), load);
@@ -1909,7 +1909,7 @@
}
const core::constant::Value* materialized_val = nullptr;
- if (!skip_const_eval_.Contains(decl)) {
+ if (!not_evaluated_.Contains(decl)) {
auto expr_val = expr->ConstantValue();
if (TINT_UNLIKELY(!expr_val)) {
ICE(decl->source) << "Materialize(" << decl->TypeInfo().name
@@ -2046,7 +2046,7 @@
const core::constant::Value* val = nullptr;
auto stage = core::EarliestStage(obj->Stage(), idx->Stage());
- if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
stage = core::EvaluationStage::kNotEvaluated;
} else {
if (auto* idx_val = idx->ConstantValue()) {
@@ -2139,7 +2139,7 @@
const core::constant::Value* value = nullptr;
auto stage = core::EarliestStage(overload_stage, args_stage);
- if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
stage = core::EvaluationStage::kNotEvaluated;
}
if (stage == core::EvaluationStage::kConstant) {
@@ -2165,7 +2165,7 @@
const sem::CallTarget* call_target) -> sem::Call* {
auto stage = args_stage; // The evaluation stage of the call
const core::constant::Value* value = nullptr; // The constant value for the call
- if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
stage = core::EvaluationStage::kNotEvaluated;
}
if (stage == core::EvaluationStage::kConstant) {
@@ -2432,7 +2432,7 @@
// now.
const core::constant::Value* value = nullptr;
auto stage = core::EarliestStage(arg_stage, target->Stage());
- if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
stage = core::EvaluationStage::kNotEvaluated;
}
if (stage == core::EvaluationStage::kConstant) {
@@ -3109,8 +3109,8 @@
return nullptr;
}
- auto stage = skip_const_eval_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
- : core::EvaluationStage::kRuntime;
+ auto stage = not_evaluated_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
+ : core::EvaluationStage::kRuntime;
// TODO(crbug.com/tint/1420): For now, assume all function calls have side effects.
bool has_side_effects = true;
@@ -3238,7 +3238,7 @@
const core::constant::Value* val = nullptr;
auto stage = core::EvaluationStage::kConstant;
- if (skip_const_eval_.Contains(literal)) {
+ if (not_evaluated_.Contains(literal)) {
stage = core::EvaluationStage::kNotEvaluated;
}
if (stage == core::EvaluationStage::kConstant) {
@@ -3292,7 +3292,7 @@
auto stage = variable->Stage();
const core::constant::Value* value = variable->ConstantValue();
- if (skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
// This expression is short-circuited by an ancestor expression.
// Do not const-eval.
stage = core::EvaluationStage::kNotEvaluated;
@@ -3632,7 +3632,7 @@
}
const core::constant::Value* value = nullptr;
- if (skip_const_eval_.Contains(expr)) {
+ if (not_evaluated_.Contains(expr)) {
// This expression is short-circuited by an ancestor expression.
// Do not const-eval.
stage = core::EvaluationStage::kNotEvaluated;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index a1abde7..80394c2 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -722,7 +722,7 @@
uint32_t current_scoping_depth_ = 0;
Hashset<TypeAndAddressSpace, 8> valid_type_storage_layouts_;
Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8> logical_binary_lhs_to_parent_;
- Hashset<const ast::Expression*, 8> skip_const_eval_;
+ Hashset<const ast::Expression*, 8> not_evaluated_;
Hashmap<const core::type::Type*, size_t, 8> nest_depth_;
Hashmap<std::pair<core::intrinsic::Overload, wgsl::BuiltinFn>, sem::BuiltinFn*, 64> builtins_;
Hashmap<core::intrinsic::Overload, sem::ValueConstructor*, 16> constructors_;
diff --git a/src/tint/lang/wgsl/sem/load.cc b/src/tint/lang/wgsl/sem/load.cc
index 1f3c18d..636c24e 100644
--- a/src/tint/lang/wgsl/sem/load.cc
+++ b/src/tint/lang/wgsl/sem/load.cc
@@ -33,10 +33,10 @@
TINT_INSTANTIATE_TYPEINFO(tint::sem::Load);
namespace tint::sem {
-Load::Load(const ValueExpression* ref, const Statement* statement)
+Load::Load(const ValueExpression* ref, const Statement* statement, core::EvaluationStage stage)
: Base(/* declaration */ ref->Declaration(),
/* type */ ref->Type()->UnwrapRef(),
- /* stage */ core::EvaluationStage::kRuntime, // Loads can only be runtime
+ /* stage */ stage,
/* statement */ statement,
/* constant */ nullptr,
/* has_side_effects */ ref->HasSideEffects(),
diff --git a/src/tint/lang/wgsl/sem/load.h b/src/tint/lang/wgsl/sem/load.h
index 035682d..1711a25 100644
--- a/src/tint/lang/wgsl/sem/load.h
+++ b/src/tint/lang/wgsl/sem/load.h
@@ -41,7 +41,8 @@
/// Constructor
/// @param reference the reference expression being loaded
/// @param statement the statement that owns this expression
- Load(const ValueExpression* reference, const Statement* statement);
+ /// @param stage the earliest evaluation stage for the expression
+ Load(const ValueExpression* reference, const Statement* statement, core::EvaluationStage stage);
/// Destructor
~Load() override;