Import Tint changes from Dawn
Changes:
- 3241bdcbbf3fd24e09293dbd89cd92950923b428 [tint] Add VectorizeScalarMatrixConstructors by James Price <jrprice@google.com>
- 75aff43b7d0d47b805ef8ca2305de63803a43853 [spirv-writer] Decompose uniform matCx3 types by James Price <jrprice@google.com>
- f27e76501791892092cc47cf709d2fef503769c4 tint/hlsl: workaround DXC bug with const structs/arrays o... by Antonio Maiorano <amaiorano@google.com>
- d6082c5f16c167d27793c693db97cb5b4cd059e4 [tint] Use TINT_ICE_ON_NO_MATCH by Ben Clayton <bclayton@google.com>
- bc87fd0783ba60a25893e4d7a2e8fe35746d009a [tint][utils] Add TINT_ICE_ON_NO_MATCH by Ben Clayton <bclayton@google.com>
- 5b5616f385905b231bc55499699b0881f321e1b4 [tint][glsl] Make the GLSL printer PIMPL by Ben Clayton <bclayton@google.com>
- d8054e28e8977ccb6b0b6fbd7fd8c4b7889ce285 [tint][msl] Make the MSL printer PIMPL by Ben Clayton <bclayton@google.com>
- 89653e0bcb589769d19d4d6cead31c58ee4f1591 [tint] Fix ICE with unreachable after chained ifs by James Price <jrprice@google.com>
- 2258155bf51be618acafd4d15f9b7e77f3c925f9 [tint][spirv] Make the IR printer PIMPL by Ben Clayton <bclayton@google.com>
- 25bfc868283f16b4cd804ac06b171697f853b63c [tint][wgsl] Reject vec4h for @builtin(position) by James Price <jrprice@google.com>
GitOrigin-RevId: 3241bdcbbf3fd24e09293dbd89cd92950923b428
Change-Id: Ie81c840fbc583a8652a802bcdba842d0a647683d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/158360
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 674b1a9..97e62d5 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -2788,11 +2788,7 @@
CreateScalar(source, mgr.types.AInt(), AInt(exp)),
};
},
- [&](Default) {
- TINT_ICE() << "unhandled element type for frexp() const-eval: "
- << s->Type()->FriendlyName();
- return FractExp{error, error};
- });
+ TINT_ICE_ON_NO_MATCH);
};
if (auto* vec = arg->Type()->As<core::type::Vector>()) {
diff --git a/src/tint/lang/core/constant/manager.cc b/src/tint/lang/core/constant/manager.cc
index 0ef1f37..ddd19d8 100644
--- a/src/tint/lang/core/constant/manager.cc
+++ b/src/tint/lang/core/constant/manager.cc
@@ -156,17 +156,14 @@
}
return Composite(s, std::move(zeros));
},
- [&](Default) -> const Value* {
- return Switch(
- type, //
- [&](const core::type::AbstractInt*) { return Get(AInt(0)); }, //
- [&](const core::type::AbstractFloat*) { return Get(AFloat(0)); }, //
- [&](const core::type::I32*) { return Get(i32(0)); }, //
- [&](const core::type::U32*) { return Get(u32(0)); }, //
- [&](const core::type::F32*) { return Get(f32(0)); }, //
- [&](const core::type::F16*) { return Get(f16(0)); }, //
- [&](const core::type::Bool*) { return Get(false); });
- });
+ [&](const core::type::AbstractInt*) { return Get(AInt(0)); }, //
+ [&](const core::type::AbstractFloat*) { return Get(AFloat(0)); }, //
+ [&](const core::type::I32*) { return Get(i32(0)); }, //
+ [&](const core::type::U32*) { return Get(u32(0)); }, //
+ [&](const core::type::F32*) { return Get(f32(0)); }, //
+ [&](const core::type::F16*) { return Get(f16(0)); }, //
+ [&](const core::type::Bool*) { return Get(false); }, //
+ TINT_ICE_ON_NO_MATCH);
}
} // namespace tint::core::constant
diff --git a/src/tint/lang/core/ir/transform/BUILD.bazel b/src/tint/lang/core/ir/transform/BUILD.bazel
index 44a3f6a..c5b1f51 100644
--- a/src/tint/lang/core/ir/transform/BUILD.bazel
+++ b/src/tint/lang/core/ir/transform/BUILD.bazel
@@ -54,6 +54,7 @@
"robustness.cc",
"shader_io.cc",
"std140.cc",
+ "vectorize_scalar_matrix_constructors.cc",
"zero_init_workgroup_memory.cc",
],
hdrs = [
@@ -72,6 +73,7 @@
"robustness.h",
"shader_io.h",
"std140.h",
+ "vectorize_scalar_matrix_constructors.h",
"zero_init_workgroup_memory.h",
],
deps = [
@@ -118,6 +120,7 @@
"preserve_padding_test.cc",
"robustness_test.cc",
"std140_test.cc",
+ "vectorize_scalar_matrix_constructors_test.cc",
"zero_init_workgroup_memory_test.cc",
] + select({
"//conditions:default": [],
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index ed4ae4b..40c0852 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -69,6 +69,8 @@
lang/core/ir/transform/shader_io.h
lang/core/ir/transform/std140.cc
lang/core/ir/transform/std140.h
+ lang/core/ir/transform/vectorize_scalar_matrix_constructors.cc
+ lang/core/ir/transform/vectorize_scalar_matrix_constructors.h
lang/core/ir/transform/zero_init_workgroup_memory.cc
lang/core/ir/transform/zero_init_workgroup_memory.h
)
@@ -116,6 +118,7 @@
lang/core/ir/transform/preserve_padding_test.cc
lang/core/ir/transform/robustness_test.cc
lang/core/ir/transform/std140_test.cc
+ lang/core/ir/transform/vectorize_scalar_matrix_constructors_test.cc
lang/core/ir/transform/zero_init_workgroup_memory_test.cc
)
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index 91305e4..f4fb43e 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -74,6 +74,8 @@
"shader_io.h",
"std140.cc",
"std140.h",
+ "vectorize_scalar_matrix_constructors.cc",
+ "vectorize_scalar_matrix_constructors.h",
"zero_init_workgroup_memory.cc",
"zero_init_workgroup_memory.h",
]
@@ -118,6 +120,7 @@
"preserve_padding_test.cc",
"robustness_test.cc",
"std140_test.cc",
+ "vectorize_scalar_matrix_constructors_test.cc",
"zero_init_workgroup_memory_test.cc",
]
deps = [
diff --git a/src/tint/lang/core/ir/transform/bgra8unorm_polyfill.cc b/src/tint/lang/core/ir/transform/bgra8unorm_polyfill.cc
index 7658d0b..5d81e6b 100644
--- a/src/tint/lang/core/ir/transform/bgra8unorm_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/bgra8unorm_polyfill.cc
@@ -174,9 +174,7 @@
// Just replace arguments to user functions and then stop.
call->SetOperand(use.operand_index, new_value);
},
- [&](Default) {
- TINT_ICE() << "unhandled instruction " << use.instruction->FriendlyName();
- });
+ TINT_ICE_ON_NO_MATCH);
});
}
};
diff --git a/src/tint/lang/core/ir/transform/direct_variable_access.cc b/src/tint/lang/core/ir/transform/direct_variable_access.cc
index 119ee28..a13d1b9 100644
--- a/src/tint/lang/core/ir/transform/direct_variable_access.cc
+++ b/src/tint/lang/core/ir/transform/direct_variable_access.cc
@@ -471,24 +471,16 @@
chain.root_ptr = var->Result();
return nullptr;
},
- [&](Let* let) { return let->Value(); },
- [&](Default) {
- TINT_ICE() << "unhandled instruction type: "
- << (inst ? inst->TypeInfo().name : "<null>");
- return nullptr;
- });
+ [&](Let* let) { return let->Value(); }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](FunctionParam* param) {
// Root pointer is a parameter of the caller
chain.shape.root = RootPtrParameter{param->Type()->As<type::Pointer>()};
chain.root_ptr = param;
return nullptr;
- },
- [&](Default) {
- TINT_ICE() << "unhandled value type: "
- << (value ? value->TypeInfo().name : "<null>");
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
// Reverse the chain's ops and indices. See above for why.
diff --git a/src/tint/lang/core/ir/transform/multiplanar_external_texture.cc b/src/tint/lang/core/ir/transform/multiplanar_external_texture.cc
index f59c2bd..7766c59 100644
--- a/src/tint/lang/core/ir/transform/multiplanar_external_texture.cc
+++ b/src/tint/lang/core/ir/transform/multiplanar_external_texture.cc
@@ -260,10 +260,8 @@
}
}
call->SetOperands(std::move(operands));
- },
- [&](Default) {
- TINT_ICE() << "unhandled instruction " << use.instruction->FriendlyName();
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
});
}
diff --git a/src/tint/lang/core/ir/transform/std140.cc b/src/tint/lang/core/ir/transform/std140.cc
index 9f5dd6b..0f2b8f6 100644
--- a/src/tint/lang/core/ir/transform/std140.cc
+++ b/src/tint/lang/core/ir/transform/std140.cc
@@ -112,7 +112,14 @@
/// @param mat the matrix type to check
/// @returns true if @p mat needs to be decomposed
- static bool NeedsDecomposing(const core::type::Matrix* mat) { return mat->ColumnStride() & 15; }
+ static bool NeedsDecomposing(const core::type::Matrix* mat) {
+ // Std140 layout rules only require us to do this transform for matrices whose column
+ // strides are not a multiple of 16 bytes.
+ //
+ // Due to a bug on Qualcomm devices, we also do this when the *size* of the column vector is
+ // not a multiple of 16 bytes (e.g. matCx3 types). See crbug.com/tint/2074.
+ return mat->ColumnType()->Size() & 15;
+ }
/// Rewrite a type if necessary, decomposing contained matrices.
/// @param type the type to rewrite
diff --git a/src/tint/lang/core/ir/transform/std140_test.cc b/src/tint/lang/core/ir/transform/std140_test.cc
index 3409ad9..8415529 100644
--- a/src/tint/lang/core/ir/transform/std140_test.cc
+++ b/src/tint/lang/core/ir/transform/std140_test.cc
@@ -60,50 +60,6 @@
EXPECT_EQ(expect, str());
}
-TEST_F(IR_Std140Test, NoModify_Mat2x3) {
- auto* mat = ty.mat2x3<f32>();
- auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
- {mod.symbols.New("a"), mat},
- });
- structure->SetStructFlag(core::type::kBlock);
-
- auto* buffer = b.Var("buffer", ty.ptr(uniform, structure));
- buffer->SetBindingPoint(0, 0);
- mod.root_block->Append(buffer);
-
- auto* func = b.Function("foo", mat);
- b.Append(func->Block(), [&] {
- auto* access = b.Access(ty.ptr(uniform, mat), buffer, 0_u);
- auto* load = b.Load(access);
- b.Return(func, load);
- });
-
- auto* src = R"(
-MyStruct = struct @align(16), @block {
- a:mat2x3<f32> @offset(0)
-}
-
-%b1 = block { # root
- %buffer:ptr<uniform, MyStruct, read_write> = var @binding_point(0, 0)
-}
-
-%foo = func():mat2x3<f32> -> %b2 {
- %b2 = block {
- %3:ptr<uniform, mat2x3<f32>, read_write> = access %buffer, 0u
- %4:mat2x3<f32> = load %3
- ret %4
- }
-}
-)";
- EXPECT_EQ(src, str());
-
- auto* expect = src;
-
- Run(Std140);
-
- EXPECT_EQ(expect, str());
-}
-
TEST_F(IR_Std140Test, NoModify_Mat2x4) {
auto* mat = ty.mat2x4<f32>();
auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
@@ -1453,6 +1409,80 @@
EXPECT_EQ(expect, str());
}
+TEST_F(IR_Std140Test, Mat4x3_LoadMatrix) {
+ auto* mat = ty.mat4x3<f32>();
+ auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
+ {mod.symbols.New("a"), mat},
+ });
+ structure->SetStructFlag(core::type::kBlock);
+
+ auto* buffer = b.Var("buffer", ty.ptr(uniform, structure));
+ buffer->SetBindingPoint(0, 0);
+ mod.root_block->Append(buffer);
+
+ auto* func = b.Function("foo", mat);
+ b.Append(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr(uniform, mat), buffer, 0_u);
+ auto* load = b.Load(access);
+ b.Return(func, load);
+ });
+
+ auto* src = R"(
+MyStruct = struct @align(16), @block {
+ a:mat4x3<f32> @offset(0)
+}
+
+%b1 = block { # root
+ %buffer:ptr<uniform, MyStruct, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():mat4x3<f32> -> %b2 {
+ %b2 = block {
+ %3:ptr<uniform, mat4x3<f32>, read_write> = access %buffer, 0u
+ %4:mat4x3<f32> = load %3
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+MyStruct = struct @align(16), @block {
+ a:mat4x3<f32> @offset(0)
+}
+
+MyStruct_std140 = struct @align(16), @block {
+ a_col0:vec3<f32> @offset(0)
+ a_col1:vec3<f32> @offset(16)
+ a_col2:vec3<f32> @offset(32)
+ a_col3:vec3<f32> @offset(48)
+}
+
+%b1 = block { # root
+ %buffer:ptr<uniform, MyStruct_std140, read_write> = var @binding_point(0, 0)
+}
+
+%foo = func():mat4x3<f32> -> %b2 {
+ %b2 = block {
+ %3:ptr<uniform, vec3<f32>, read_write> = access %buffer, 0u
+ %4:vec3<f32> = load %3
+ %5:ptr<uniform, vec3<f32>, read_write> = access %buffer, 1u
+ %6:vec3<f32> = load %5
+ %7:ptr<uniform, vec3<f32>, read_write> = access %buffer, 2u
+ %8:vec3<f32> = load %7
+ %9:ptr<uniform, vec3<f32>, read_write> = access %buffer, 3u
+ %10:vec3<f32> = load %9
+ %11:mat4x3<f32> = construct %4, %6, %8, %10
+ ret %11
+ }
+}
+)";
+
+ Run(Std140);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(IR_Std140Test, F16) {
auto* structure =
ty.Struct(mod.symbols.New("MyStruct"), {
diff --git a/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.cc b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.cc
new file mode 100644
index 0000000..68e365a
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.cc
@@ -0,0 +1,109 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+ /// The IR module.
+ Module& ir;
+
+ /// The IR builder.
+ Builder b{ir};
+
+ /// The type manager.
+ core::type::Manager& ty{ir.Types()};
+
+ /// Process the module.
+ void Process() {
+ // Find and replace matrix constructors that take scalar operands.
+ Vector<Construct*, 8> worklist;
+ for (auto inst : ir.instructions.Objects()) {
+ if (auto* construct = inst->As<Construct>(); construct && construct->Alive()) {
+ if (construct->Result()->Type()->As<type::Matrix>()) {
+ if (construct->Operands().Length() > 0 &&
+ construct->Operands()[0]->Type()->Is<type::Scalar>()) {
+ b.InsertBefore(construct, [&] { //
+ ReplaceConstructor(construct);
+ });
+ }
+ }
+ }
+ }
+ }
+
+ /// Replace a matrix construct instruction.
+ /// @param construct the instruction to replace
+ void ReplaceConstructor(Construct* construct) {
+ auto* mat = construct->Result()->Type()->As<type::Matrix>();
+ auto* col = mat->ColumnType();
+ const auto& scalars = construct->Operands();
+
+ // Collect consecutive scalars into column vectors.
+ Vector<Value*, 4> columns;
+ for (uint32_t c = 0; c < mat->columns(); c++) {
+ Vector<Value*, 4> values;
+ for (uint32_t r = 0; r < col->Width(); r++) {
+ values.Push(scalars[c * col->Width() + r]);
+ }
+ columns.Push(b.Construct(col, std::move(values))->Result());
+ }
+
+ // Construct the matrix from the column vectors and replace the original instruction.
+ auto* replacement = b.Construct(mat, std::move(columns))->Result();
+ construct->Result()->ReplaceAllUsesWith(replacement);
+ construct->Destroy();
+ }
+};
+
+} // namespace
+
+Result<SuccessType> VectorizeScalarMatrixConstructors(Module& ir) {
+ auto result = ValidateAndDumpIfNeeded(ir, "VectorizeScalarMatrixConstructors transform");
+ if (!result) {
+ return result;
+ }
+
+ State{ir}.Process();
+
+ return Success;
+}
+
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h
new file mode 100644
index 0000000..ff4ef93
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h
@@ -0,0 +1,49 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// VectorizeScalarMatrixConstructors is a transform that replaces construct instructions that
+/// produce matrices from scalar operands to construct individual columns first.
+///
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> VectorizeScalarMatrixConstructors(Module& module);
+
+} // namespace tint::core::ir::transform
+
+#endif // SRC_TINT_LANG_CORE_IR_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_
diff --git a/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors_test.cc b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors_test.cc
new file mode 100644
index 0000000..4f42888
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors_test.cc
@@ -0,0 +1,576 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/matrix.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using IR_VectorizeScalarMatrixConstructorsTest = TransformTest;
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, NoModify_NoOperands) {
+ auto* mat = ty.mat3x3<f32>();
+ auto* func = b.Function("foo", mat);
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func():mat3x3<f32> -> %b1 {
+ %b1 = block {
+ %2:mat3x3<f32> = construct
+ ret %2
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, NoModify_Identity) {
+ auto* mat = ty.mat3x3<f32>();
+ auto* value = b.FunctionParam("value", mat);
+ auto* func = b.Function("foo", mat);
+ func->SetParams({value});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, value);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%value:mat3x3<f32>):mat3x3<f32> -> %b1 {
+ %b1 = block {
+ %3:mat3x3<f32> = construct %value
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, NoModify_Vectors) {
+ auto* mat = ty.mat3x3<f32>();
+ auto* v1 = b.FunctionParam("v1", mat->ColumnType());
+ auto* v2 = b.FunctionParam("v2", mat->ColumnType());
+ auto* v3 = b.FunctionParam("v3", mat->ColumnType());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:vec3<f32>, %v2:vec3<f32>, %v3:vec3<f32>):mat3x3<f32> -> %b1 {
+ %b1 = block {
+ %5:mat3x3<f32> = construct %v1, %v2, %v3
+ ret %5
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat2x2) {
+ auto* mat = ty.mat2x2<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32):mat2x2<f32> -> %b1 {
+ %b1 = block {
+ %6:mat2x2<f32> = construct %v1, %v2, %v3, %v4
+ ret %6
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32):mat2x2<f32> -> %b1 {
+ %b1 = block {
+ %6:vec2<f32> = construct %v1, %v2
+ %7:vec2<f32> = construct %v3, %v4
+ %8:mat2x2<f32> = construct %6, %7
+ ret %8
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat2x3) {
+ auto* mat = ty.mat2x3<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %8:mat2x3<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6
+ ret %8
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32):mat2x3<f32> -> %b1 {
+ %b1 = block {
+ %8:vec3<f32> = construct %v1, %v2, %v3
+ %9:vec3<f32> = construct %v4, %v5, %v6
+ %10:mat2x3<f32> = construct %8, %9
+ ret %10
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat2x4) {
+ auto* mat = ty.mat2x4<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32):mat2x4<f32> -> %b1 {
+ %b1 = block {
+ %10:mat2x4<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8
+ ret %10
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32):mat2x4<f32> -> %b1 {
+ %b1 = block {
+ %10:vec4<f32> = construct %v1, %v2, %v3, %v4
+ %11:vec4<f32> = construct %v5, %v6, %v7, %v8
+ %12:mat2x4<f32> = construct %10, %11
+ ret %12
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat3x2) {
+ auto* mat = ty.mat3x2<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32):mat3x2<f32> -> %b1 {
+ %b1 = block {
+ %8:mat3x2<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6
+ ret %8
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32):mat3x2<f32> -> %b1 {
+ %b1 = block {
+ %8:vec2<f32> = construct %v1, %v2
+ %9:vec2<f32> = construct %v3, %v4
+ %10:vec2<f32> = construct %v5, %v6
+ %11:mat3x2<f32> = construct %8, %9, %10
+ ret %11
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat3x3) {
+ auto* mat = ty.mat3x3<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* v9 = b.FunctionParam("v9", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8, v9});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8, v9);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32):mat3x3<f32> -> %b1 {
+ %b1 = block {
+ %11:mat3x3<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8, %v9
+ ret %11
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32):mat3x3<f32> -> %b1 {
+ %b1 = block {
+ %11:vec3<f32> = construct %v1, %v2, %v3
+ %12:vec3<f32> = construct %v4, %v5, %v6
+ %13:vec3<f32> = construct %v7, %v8, %v9
+ %14:mat3x3<f32> = construct %11, %12, %13
+ ret %14
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat3x4) {
+ auto* mat = ty.mat3x4<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* v9 = b.FunctionParam("v9", ty.f32());
+ auto* v10 = b.FunctionParam("v10", ty.f32());
+ auto* v11 = b.FunctionParam("v11", ty.f32());
+ auto* v12 = b.FunctionParam("v12", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32):mat3x4<f32> -> %b1 {
+ %b1 = block {
+ %14:mat3x4<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8, %v9, %v10, %v11, %v12
+ ret %14
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32):mat3x4<f32> -> %b1 {
+ %b1 = block {
+ %14:vec4<f32> = construct %v1, %v2, %v3, %v4
+ %15:vec4<f32> = construct %v5, %v6, %v7, %v8
+ %16:vec4<f32> = construct %v9, %v10, %v11, %v12
+ %17:mat3x4<f32> = construct %14, %15, %16
+ ret %17
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat4x2) {
+ auto* mat = ty.mat4x2<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32):mat4x2<f32> -> %b1 {
+ %b1 = block {
+ %10:mat4x2<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8
+ ret %10
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32):mat4x2<f32> -> %b1 {
+ %b1 = block {
+ %10:vec2<f32> = construct %v1, %v2
+ %11:vec2<f32> = construct %v3, %v4
+ %12:vec2<f32> = construct %v5, %v6
+ %13:vec2<f32> = construct %v7, %v8
+ %14:mat4x2<f32> = construct %10, %11, %12, %13
+ ret %14
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat4x3) {
+ auto* mat = ty.mat4x3<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* v9 = b.FunctionParam("v9", ty.f32());
+ auto* v10 = b.FunctionParam("v10", ty.f32());
+ auto* v11 = b.FunctionParam("v11", ty.f32());
+ auto* v12 = b.FunctionParam("v12", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32):mat4x3<f32> -> %b1 {
+ %b1 = block {
+ %14:mat4x3<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8, %v9, %v10, %v11, %v12
+ ret %14
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32):mat4x3<f32> -> %b1 {
+ %b1 = block {
+ %14:vec3<f32> = construct %v1, %v2, %v3
+ %15:vec3<f32> = construct %v4, %v5, %v6
+ %16:vec3<f32> = construct %v7, %v8, %v9
+ %17:vec3<f32> = construct %v10, %v11, %v12
+ %18:mat4x3<f32> = construct %14, %15, %16, %17
+ ret %18
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat4x4) {
+ auto* mat = ty.mat4x4<f32>();
+ auto* v1 = b.FunctionParam("v1", ty.f32());
+ auto* v2 = b.FunctionParam("v2", ty.f32());
+ auto* v3 = b.FunctionParam("v3", ty.f32());
+ auto* v4 = b.FunctionParam("v4", ty.f32());
+ auto* v5 = b.FunctionParam("v5", ty.f32());
+ auto* v6 = b.FunctionParam("v6", ty.f32());
+ auto* v7 = b.FunctionParam("v7", ty.f32());
+ auto* v8 = b.FunctionParam("v8", ty.f32());
+ auto* v9 = b.FunctionParam("v9", ty.f32());
+ auto* v10 = b.FunctionParam("v10", ty.f32());
+ auto* v11 = b.FunctionParam("v11", ty.f32());
+ auto* v12 = b.FunctionParam("v12", ty.f32());
+ auto* v13 = b.FunctionParam("v13", ty.f32());
+ auto* v14 = b.FunctionParam("v14", ty.f32());
+ auto* v15 = b.FunctionParam("v15", ty.f32());
+ auto* v16 = b.FunctionParam("v16", ty.f32());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16});
+ b.Append(func->Block(), [&] {
+ auto* construct =
+ b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32, %v13:f32, %v14:f32, %v15:f32, %v16:f32):mat4x4<f32> -> %b1 {
+ %b1 = block {
+ %18:mat4x4<f32> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8, %v9, %v10, %v11, %v12, %v13, %v14, %v15, %v16
+ ret %18
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f32, %v2:f32, %v3:f32, %v4:f32, %v5:f32, %v6:f32, %v7:f32, %v8:f32, %v9:f32, %v10:f32, %v11:f32, %v12:f32, %v13:f32, %v14:f32, %v15:f32, %v16:f32):mat4x4<f32> -> %b1 {
+ %b1 = block {
+ %18:vec4<f32> = construct %v1, %v2, %v3, %v4
+ %19:vec4<f32> = construct %v5, %v6, %v7, %v8
+ %20:vec4<f32> = construct %v9, %v10, %v11, %v12
+ %21:vec4<f32> = construct %v13, %v14, %v15, %v16
+ %22:mat4x4<f32> = construct %18, %19, %20, %21
+ ret %22
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_VectorizeScalarMatrixConstructorsTest, Mat3x3_F16) {
+ auto* mat = ty.mat3x3<f16>();
+ auto* v1 = b.FunctionParam("v1", ty.f16());
+ auto* v2 = b.FunctionParam("v2", ty.f16());
+ auto* v3 = b.FunctionParam("v3", ty.f16());
+ auto* v4 = b.FunctionParam("v4", ty.f16());
+ auto* v5 = b.FunctionParam("v5", ty.f16());
+ auto* v6 = b.FunctionParam("v6", ty.f16());
+ auto* v7 = b.FunctionParam("v7", ty.f16());
+ auto* v8 = b.FunctionParam("v8", ty.f16());
+ auto* v9 = b.FunctionParam("v9", ty.f16());
+ auto* func = b.Function("foo", mat);
+ func->SetParams({v1, v2, v3, v4, v5, v6, v7, v8, v9});
+ b.Append(func->Block(), [&] {
+ auto* construct = b.Construct(mat, v1, v2, v3, v4, v5, v6, v7, v8, v9);
+ b.Return(func, construct->Result());
+ });
+
+ auto* src = R"(
+%foo = func(%v1:f16, %v2:f16, %v3:f16, %v4:f16, %v5:f16, %v6:f16, %v7:f16, %v8:f16, %v9:f16):mat3x3<f16> -> %b1 {
+ %b1 = block {
+ %11:mat3x3<f16> = construct %v1, %v2, %v3, %v4, %v5, %v6, %v7, %v8, %v9
+ ret %11
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%v1:f16, %v2:f16, %v3:f16, %v4:f16, %v5:f16, %v6:f16, %v7:f16, %v8:f16, %v9:f16):mat3x3<f16> -> %b1 {
+ %b1 = block {
+ %11:vec3<f16> = construct %v1, %v2, %v3
+ %12:vec3<f16> = construct %v4, %v5, %v6
+ %13:vec3<f16> = construct %v7, %v8, %v9
+ %14:mat3x3<f16> = construct %11, %12, %13
+ ret %14
+ }
+}
+)";
+
+ Run(VectorizeScalarMatrixConstructors);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
index 0393fc8..927844b 100644
--- a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
@@ -262,8 +262,8 @@
new_indices.Push(member->Index());
PrepareStores(var, member->Type(), iteration_count, new_indices, stores);
}
- },
- [&](Default) { TINT_UNREACHABLE(); });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Get or inject an entry point builtin for the local invocation index.
diff --git a/src/tint/lang/core/type/builtin_structs.cc b/src/tint/lang/core/type/builtin_structs.cc
index 09c4587..8f319d5 100644
--- a/src/tint/lang/core/type/builtin_structs.cc
+++ b/src/tint/lang/core/type/builtin_structs.cc
@@ -98,16 +98,10 @@
build(kModfVecF16Names[width - 2], types.vec(types.f16(), width)),
});
return abstract;
- },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled modf type";
- return nullptr;
- });
- },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled modf type";
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// An array of `frexp()` return type names for an argument of `vecN<f32>`.
@@ -170,16 +164,10 @@
build(kFrexpVecF16Names[width - 2], vec_f16, vec_i32),
});
return abstract;
- },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled frexp type";
- return nullptr;
- });
- },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled frexp type";
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
Struct* CreateAtomicCompareExchangeResult(Manager& types, SymbolTable& symbols, const Type* ty) {
@@ -194,10 +182,7 @@
ty, //
[&](const I32*) { return build(core::BuiltinType::kAtomicCompareExchangeResultI32); },
[&](const U32*) { return build(core::BuiltinType::kAtomicCompareExchangeResultU32); },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled atomic_compare_exchange type";
- return nullptr;
- });
+ TINT_ICE_ON_NO_MATCH);
}
} // namespace tint::core::type
diff --git a/src/tint/lang/core/type/manager.h b/src/tint/lang/core/type/manager.h
index 687f465..6d204c4 100644
--- a/src/tint/lang/core/type/manager.h
+++ b/src/tint/lang/core/type/manager.h
@@ -129,6 +129,8 @@
return ptr<T::address, typename T::type, T::access>(std::forward<ARGS>(args)...);
} else if constexpr (core::fluent_types::IsArray<T>) {
return array<typename T::type, T::length>(std::forward<ARGS>(args)...);
+ } else if constexpr (core::fluent_types::IsAtomic<T>) {
+ return atomic<typename T::type>(std::forward<ARGS>(args)...);
} else if constexpr (tint::traits::IsTypeOrDerived<T, Type>) {
return types_.Get<T>(std::forward<ARGS>(args)...);
} else if constexpr (tint::traits::IsTypeOrDerived<T, UniqueNode>) {
diff --git a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
index fc9a7ed..cec5d7e 100644
--- a/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/glsl/writer/ast_printer/ast_printer.cc
@@ -318,10 +318,8 @@
[&](const ast::Enable* enable) {
// Record the required extension for generating extension directive later
RecordExtension(enable);
- },
- [&](Default) {
- TINT_ICE() << "unhandled module-scope declaration: " << decl->TypeInfo().name;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
TextBuffer extensions;
@@ -806,9 +804,7 @@
[&](const sem::BuiltinFn* builtin) { EmitBuiltinCall(out, call, builtin); },
[&](const sem::ValueConversion* conv) { EmitValueConversion(out, call, conv); },
[&](const sem::ValueConstructor* ctor) { EmitValueConstructor(out, call, ctor); },
- [&](Default) {
- TINT_ICE() << "unhandled call target: " << call->Target()->TypeInfo().name;
- });
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitFunctionCall(StringStream& out,
@@ -1795,11 +1791,8 @@
[&](const ast::IdentifierExpression* i) { EmitIdentifier(out, i); },
[&](const ast::LiteralExpression* l) { EmitLiteral(out, l); },
[&](const ast::MemberAccessorExpression* m) { EmitMemberAccessor(out, m); },
- [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(out, u); },
- [&](Default) { //
- diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
- std::string(expr->TypeInfo().name));
- });
+ [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(out, u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitIdentifier(StringStream& out, const ast::IdentifierExpression* expr) {
@@ -1922,10 +1915,8 @@
},
[&](const ast::Const*) {
// Constants are embedded at their use
- },
- [&](Default) {
- TINT_ICE() << "unhandled global variable type " << global->TypeInfo().name;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
@@ -2307,11 +2298,8 @@
}
EmitConstant(out, constant->Index(i));
}
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unhandled constant type: " + constant->Type()->FriendlyName());
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitLiteral(StringStream& out, const ast::LiteralExpression* lit) {
@@ -2338,8 +2326,8 @@
}
}
diagnostics_.add_error(diag::System::Writer, "unknown integer literal suffix type");
- },
- [&](Default) { diagnostics_.add_error(diag::System::Writer, "unknown literal type"); });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitZeroValue(StringStream& out, const core::type::Type* type) {
@@ -2571,8 +2559,8 @@
},
[&](const sem::StructMemberAccess* member_access) {
out << member_access->Member()->Name().Name();
- },
- [&](Default) { TINT_ICE() << "unknown member access type: " << sem->TypeInfo().name; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitReturn(const ast::ReturnStatement* stmt) {
@@ -2613,18 +2601,13 @@
[&](const ast::Let* let) { EmitLet(let); },
[&](const ast::Const*) {
// Constants are embedded at their use
- },
- [&](Default) { //
- TINT_ICE() << "unknown variable type: " << v->variable->TypeInfo().name;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](const ast::ConstAssert*) {
// Not emitted
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitSwitch(const ast::SwitchStatement* stmt) {
diff --git a/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc b/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
index 644a0a5..67eb338 100644
--- a/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
+++ b/src/tint/lang/glsl/writer/ast_raise/texture_builtins_from_uniform.cc
@@ -160,10 +160,8 @@
GetAndRecordFunctionParameter(fn, variable, dataType);
// Record the call and new_param to be replaced later.
builtin_to_replace.Add(call_expr, new_param);
- },
- [&](Default) {
- TINT_ICE() << "unexpected texture root identifier";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](const sem::Function* user_fn) {
auto user_param_to_info = fn_to_data.Find(user_fn);
@@ -200,10 +198,8 @@
fn, variable, info->field);
// Record adding extra function parameter
args.Push(new_param);
- },
- [&](Default) {
- TINT_ICE() << "unexpected texture root identifier";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
}
});
diff --git a/src/tint/lang/glsl/writer/printer/BUILD.bazel b/src/tint/lang/glsl/writer/printer/BUILD.bazel
index f54faec..8961c7a 100644
--- a/src/tint/lang/glsl/writer/printer/BUILD.bazel
+++ b/src/tint/lang/glsl/writer/printer/BUILD.bazel
@@ -91,7 +91,6 @@
"//src/tint/lang/glsl/writer/raise",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
- "//src/tint/utils/generator",
"//src/tint/utils/ice",
"//src/tint/utils/id",
"//src/tint/utils/macros",
diff --git a/src/tint/lang/glsl/writer/printer/BUILD.cmake b/src/tint/lang/glsl/writer/printer/BUILD.cmake
index 73f0081..31486dd 100644
--- a/src/tint/lang/glsl/writer/printer/BUILD.cmake
+++ b/src/tint/lang/glsl/writer/printer/BUILD.cmake
@@ -96,7 +96,6 @@
tint_lang_glsl_writer_raise
tint_utils_containers
tint_utils_diagnostic
- tint_utils_generator
tint_utils_ice
tint_utils_id
tint_utils_macros
diff --git a/src/tint/lang/glsl/writer/printer/BUILD.gn b/src/tint/lang/glsl/writer/printer/BUILD.gn
index 20efd60..59f838a 100644
--- a/src/tint/lang/glsl/writer/printer/BUILD.gn
+++ b/src/tint/lang/glsl/writer/printer/BUILD.gn
@@ -93,7 +93,6 @@
"${tint_src_dir}/lang/glsl/writer/raise",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
- "${tint_src_dir}/utils/generator",
"${tint_src_dir}/utils/ice",
"${tint_src_dir}/utils/id",
"${tint_src_dir}/utils/macros",
diff --git a/src/tint/lang/glsl/writer/printer/helper_test.h b/src/tint/lang/glsl/writer/printer/helper_test.h
index 709fb17..8a420e6 100644
--- a/src/tint/lang/glsl/writer/printer/helper_test.h
+++ b/src/tint/lang/glsl/writer/printer/helper_test.h
@@ -33,6 +33,7 @@
#include "gtest/gtest.h"
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/glsl/writer/common/version.h"
#include "src/tint/lang/glsl/writer/printer/printer.h"
#include "src/tint/lang/glsl/writer/raise/raise.h"
@@ -42,8 +43,6 @@
template <typename BASE>
class GlslPrinterTestHelperBase : public BASE {
public:
- GlslPrinterTestHelperBase() : writer_(mod) {}
-
/// The test module.
core::ir::Module mod;
/// The test builder.
@@ -54,9 +53,6 @@
Version version{};
protected:
- /// The GLSL writer.
- Printer writer_;
-
/// Validation errors
std::string err_;
@@ -66,18 +62,17 @@
/// Run the writer on the IR module and validate the result.
/// @returns true if generation and validation succeeded
bool Generate() {
- auto raised = raise::Raise(mod);
- if (!raised) {
+ if (auto raised = raise::Raise(mod); !raised) {
err_ = raised.Failure().reason.str();
return false;
}
- auto result = writer_.Generate(version);
+ auto result = Print(mod, version);
if (!result) {
err_ = result.Failure().reason.str();
return false;
}
- output_ = writer_.Result();
+ output_ = result.Get();
return true;
}
diff --git a/src/tint/lang/glsl/writer/printer/printer.cc b/src/tint/lang/glsl/writer/printer/printer.cc
index 6351b29..f3dda8c 100644
--- a/src/tint/lang/glsl/writer/printer/printer.cc
+++ b/src/tint/lang/glsl/writer/printer/printer.cc
@@ -27,123 +27,150 @@
#include "src/tint/lang/glsl/writer/printer/printer.h"
+#include <string>
#include <utility>
+#include "src/tint/lang/core/ir/function.h"
+#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/unreachable.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/glsl/writer/common/version.h"
+#include "src/tint/utils/generator/text_generator.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/rtti/switch.h"
using namespace tint::core::fluent_types; // NOLINT
namespace tint::glsl::writer {
+namespace {
-// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
-#define UNHANDLED_CASE(object_ptr) \
- TINT_UNIMPLEMENTED() << "unhandled case in Switch(): " \
- << (object_ptr ? object_ptr->TypeInfo().name : "<null>")
+/// PIMPL class for the MSL generator
+class Printer : public tint::TextGenerator {
+ public:
+ /// Constructor
+ /// @param module the Tint IR module to generate
+ explicit Printer(core::ir::Module& module) : ir_(module) {}
-Printer::Printer(core::ir::Module& module) : ir_(module) {}
+ /// @param version the GLSL version information
+ /// @returns the generated GLSL shader
+ tint::Result<std::string> Generate(const Version& version) {
+ auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "GLSL writer");
+ if (!valid) {
+ return std::move(valid.Failure());
+ }
-Printer::~Printer() = default;
+ {
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
-tint::Result<SuccessType> Printer::Generate(Version version) {
- auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "GLSL writer");
- if (!valid) {
- return std::move(valid.Failure());
+ auto out = Line();
+ out << "#version " << version.major_version << version.minor_version << "0";
+ if (version.IsES()) {
+ out << " es";
+ }
+ }
+
+ // Emit module-scope declarations.
+ EmitBlockInstructions(ir_.root_block);
+
+ // Emit functions.
+ for (auto* func : ir_.functions) {
+ EmitFunction(func);
+ }
+
+ StringStream ss;
+ ss << preamble_buffer_.String() << '\n' << main_buffer_.String();
+ return ss.str();
}
- {
- TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
+ private:
+ core::ir::Module& ir_;
- auto out = Line();
- out << "#version " << version.major_version << version.minor_version << "0";
- if (version.IsES()) {
- out << " es";
+ /// The buffer holding preamble text
+ TextBuffer preamble_buffer_;
+
+ /// The current function being emitted
+ core::ir::Function* current_function_ = nullptr;
+ /// The current block being emitted
+ core::ir::Block* current_block_ = nullptr;
+
+ /// Emit the function
+ /// @param func the function to emit
+ void EmitFunction(core::ir::Function* func) {
+ TINT_SCOPED_ASSIGNMENT(current_function_, func);
+
+ {
+ auto out = Line();
+
+ // TODO(dsinclair): Emit function stage if any
+ // TODO(dsinclair): Handle return type attributes
+
+ EmitType(out, func->ReturnType());
+ out << " " << ir_.NameOf(func).Name() << "() {";
+
+ // TODO(dsinclair): Emit Function parameters
+ }
+ {
+ ScopedIndent si(current_buffer_);
+ EmitBlock(func->Block());
+ }
+
+ Line() << "}";
+ }
+
+ /// Emit a block
+ /// @param block the block to emit
+ void EmitBlock(core::ir::Block* block) {
+ // TODO(dsinclair): Handle marking inline
+ // MarkInlinable(block);
+
+ EmitBlockInstructions(block);
+ }
+
+ /// Emit the instructions in a block
+ /// @param block the block with the instructions to emit
+ void EmitBlockInstructions(core::ir::Block* block) {
+ TINT_SCOPED_ASSIGNMENT(current_block_, block);
+
+ for (auto* inst : *block) {
+ Switch(
+ inst, //
+ [&](core::ir::Return* r) { EmitReturn(r); }, //
+ [&](core::ir::Unreachable*) { EmitUnreachable(); }, //
+ TINT_ICE_ON_NO_MATCH);
}
}
- // Emit module-scope declarations.
- EmitBlockInstructions(ir_.root_block);
+ /// Emit a type
+ /// @param out the stream to emit too
+ /// @param ty the type to emit
+ void EmitType(StringStream& out, [[maybe_unused]] const core::type::Type* ty) { out << "void"; }
- // Emit functions.
- for (auto* func : ir_.functions) {
- EmitFunction(func);
- }
+ /// Emit a return instruction
+ /// @param r the return instruction
+ void EmitReturn(core::ir::Return* r) {
+ // If this return has no arguments and the current block is for the function which is
+ // being returned, skip the return.
+ if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
+ return;
+ }
- return Success;
-}
-
-std::string Printer::Result() const {
- StringStream ss;
- ss << preamble_buffer_.String() << '\n' << main_buffer_.String();
- return ss.str();
-}
-
-void Printer::EmitFunction(core::ir::Function* func) {
- TINT_SCOPED_ASSIGNMENT(current_function_, func);
-
- {
auto out = Line();
-
- // TODO(dsinclair): Emit function stage if any
- // TODO(dsinclair): Handle return type attributes
-
- EmitType(out, func->ReturnType());
- out << " " << ir_.NameOf(func).Name() << "() {";
-
- // TODO(dsinclair): Emit Function parameters
- }
- {
- ScopedIndent si(current_buffer_);
- EmitBlock(func->Block());
+ out << "return";
+ // TODO(dsinclair): Handle return args
+ // if (!r->Args().IsEmpty()) {
+ // out << " " << Expr(r->Args().Front());
+ // }
+ out << ";";
}
- Line() << "}";
-}
+ /// Emit an unreachable instruction
+ void EmitUnreachable() { Line() << "/* unreachable */"; }
+};
+} // namespace
-void Printer::EmitBlock(core::ir::Block* block) {
- // TODO(dsinclair): Handle marking inline
- // MarkInlinable(block);
-
- EmitBlockInstructions(block);
-}
-
-void Printer::EmitBlockInstructions(core::ir::Block* block) {
- TINT_SCOPED_ASSIGNMENT(current_block_, block);
-
- for (auto* inst : *block) {
- Switch(
- inst, //
- [&](core::ir::Return* r) { EmitReturn(r); }, //
- [&](core::ir::Unreachable*) { EmitUnreachable(); }, //
- [&](Default) { UNHANDLED_CASE(inst); });
- }
-}
-
-void Printer::EmitType(StringStream& out, [[maybe_unused]] const core::type::Type* ty) {
- out << "void";
-}
-
-void Printer::EmitReturn(core::ir::Return* r) {
- // If this return has no arguments and the current block is for the function which is
- // being returned, skip the return.
- if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
- return;
- }
-
- auto out = Line();
- out << "return";
- // TODO(dsinclair): Handle return args
- // if (!r->Args().IsEmpty()) {
- // out << " " << Expr(r->Args().Front());
- // }
- out << ";";
-}
-
-void Printer::EmitUnreachable() {
- Line() << "/* unreachable */";
+Result<std::string> Print(core::ir::Module& module, const Version& version) {
+ return Printer{module}.Generate(version);
}
} // namespace tint::glsl::writer
diff --git a/src/tint/lang/glsl/writer/printer/printer.h b/src/tint/lang/glsl/writer/printer/printer.h
index 5df46df..49c6301 100644
--- a/src/tint/lang/glsl/writer/printer/printer.h
+++ b/src/tint/lang/glsl/writer/printer/printer.h
@@ -30,72 +30,22 @@
#include <string>
-#include "src/tint/lang/core/ir/module.h"
-#include "src/tint/lang/glsl/writer/common/version.h"
-#include "src/tint/utils/generator/text_generator.h"
+#include "src/tint/utils/result/result.h"
// Forward declarations
namespace tint::core::ir {
-class Binary;
-class ExitIf;
-class If;
-class Let;
-class Load;
-class Return;
-class Unreachable;
-class Var;
+class Module;
} // namespace tint::core::ir
+namespace tint::glsl::writer {
+struct Version;
+} // namespace tint::glsl::writer
namespace tint::glsl::writer {
-/// Implementation class for the MSL generator
-class Printer : public tint::TextGenerator {
- public:
- /// Constructor
- /// @param module the Tint IR module to generate
- explicit Printer(core::ir::Module& module);
- ~Printer() override;
-
- /// @param version the GLSL version information
- /// @returns success or failure
- tint::Result<SuccessType> Generate(Version version);
-
- /// @copydoc tint::TextGenerator::Result
- std::string Result() const override;
-
- private:
- /// Emit the function
- /// @param func the function to emit
- void EmitFunction(core::ir::Function* func);
-
- /// Emit a block
- /// @param block the block to emit
- void EmitBlock(core::ir::Block* block);
- /// Emit the instructions in a block
- /// @param block the block with the instructions to emit
- void EmitBlockInstructions(core::ir::Block* block);
-
- /// Emit a return instruction
- /// @param r the return instruction
- void EmitReturn(core::ir::Return* r);
- /// Emit an unreachable instruction
- void EmitUnreachable();
-
- /// Emit a type
- /// @param out the stream to emit too
- /// @param ty the type to emit
- void EmitType(StringStream& out, const core::type::Type* ty);
-
- core::ir::Module& ir_;
-
- /// The buffer holding preamble text
- TextBuffer preamble_buffer_;
-
- /// The current function being emitted
- core::ir::Function* current_function_ = nullptr;
- /// The current block being emitted
- core::ir::Block* current_block_ = nullptr;
-};
+/// @returns the generated GLSL shader on success, or failure
+/// @param module the Tint IR module to generate
+/// @param version the GLSL version information
+Result<std::string> Print(core::ir::Module& module, const Version& version);
} // namespace tint::glsl::writer
diff --git a/src/tint/lang/glsl/writer/writer.cc b/src/tint/lang/glsl/writer/writer.cc
index bc0f045..452bd4b 100644
--- a/src/tint/lang/glsl/writer/writer.cc
+++ b/src/tint/lang/glsl/writer/writer.cc
@@ -71,12 +71,11 @@
}
// Generate the GLSL code.
- auto impl = std::make_unique<Printer>(ir);
- auto result = impl->Generate(options.version);
+ auto result = Print(ir, options.version);
if (!result) {
return result.Failure();
}
- output.glsl = impl->Result();
+ output.glsl = result.Get();
#else
return Failure{"use_tint_ir requires building with TINT_BUILD_WGSL_READER"};
#endif
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
index 6ba1dc1..d84e98a 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
@@ -411,11 +411,8 @@
return EmitEntryPointFunction(func);
}
return EmitFunction(func);
- },
- [&](Default) {
- TINT_ICE() << "unhandled module-scope declaration: " << decl->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
@@ -1106,10 +1103,7 @@
[&](const sem::BuiltinFn* builtin) { return EmitBuiltinCall(out, call, builtin); },
[&](const sem::ValueConversion* conv) { return EmitValueConversion(out, call, conv); },
[&](const sem::ValueConstructor* ctor) { return EmitValueConstructor(out, call, ctor); },
- [&](Default) {
- TINT_ICE() << "unhandled call target: " << target->TypeInfo().name;
- return false;
- });
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitFunctionCall(StringStream& out,
@@ -3119,12 +3113,8 @@
[&](const ast::IdentifierExpression* i) { return EmitIdentifier(out, i); },
[&](const ast::LiteralExpression* l) { return EmitLiteral(out, l); },
[&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(out, m); },
- [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
- std::string(expr->TypeInfo().name));
- return false;
- });
+ [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitIdentifier(StringStream& out, const ast::IdentifierExpression* expr) {
@@ -3335,12 +3325,8 @@
},
[&](const ast::Const*) {
return true; // Constants are embedded at their use
- },
- [&](Default) {
- TINT_ICE() << "unhandled global variable type " << global->TypeInfo().name;
-
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
@@ -3756,12 +3742,8 @@
}
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unhandled constant type: " + constant->Type()->FriendlyName());
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitLiteral(StringStream& out, const ast::LiteralExpression* lit) {
@@ -3793,11 +3775,8 @@
}
diagnostics_.add_error(diag::System::Writer, "unknown integer literal suffix type");
return false;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "unknown literal type");
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitValue(StringStream& out, const core::type::Type* type, int value) {
@@ -3866,12 +3845,8 @@
TINT_DEFER(out << ")" << value);
return EmitType(out, type, core::AddressSpace::kUndefined, core::Access::kUndefined,
"");
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "Invalid type for value emission: " + type->FriendlyName());
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitZeroValue(StringStream& out, const core::type::Type* type) {
@@ -4081,11 +4056,8 @@
[&](const sem::StructMemberAccess* member_access) {
out << member_access->Member()->Name().Name();
return true;
- },
- [&](Default) {
- TINT_ICE() << "unknown member access type: " << sem->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitReturn(const ast::ReturnStatement* stmt) {
@@ -4156,20 +4128,13 @@
[&](const ast::Let* let) { return EmitLet(let); },
[&](const ast::Const*) {
return true; // Constants are embedded at their use
- },
- [&](Default) { //
- TINT_ICE() << "unknown variable type: " << v->variable->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](const ast::ConstAssert*) {
return true; // Not emitted
- },
- [&](Default) { //
- diagnostics_.add_error(diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@@ -4450,11 +4415,8 @@
[&](const core::type::Void*) {
out << "void";
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitTypeAndName(StringStream& out,
@@ -4601,12 +4563,41 @@
return true;
}
+bool ASTPrinter::IsStructOrArrayOfMatrix(const core::type::Type* ty) {
+ if (!ty->IsAnyOf<core::type::Struct, core::type::Array>()) {
+ return false;
+ }
+ return GetOrCreate(is_struct_or_array_of_matrix_, ty, [&]() {
+ Vector<const core::type::Type*, 4> to_visit({ty});
+ while (!to_visit.IsEmpty()) {
+ auto* curr = to_visit.Pop();
+ if (curr->Is<core::type::Matrix>()) {
+ return true;
+ }
+ auto [child_ty, child_count] = curr->Elements();
+ if (child_ty) {
+ to_visit.Push(child_ty);
+ } else {
+ for (uint32_t i = 0; i < child_count; ++i) {
+ to_visit.Push(curr->Element(i));
+ }
+ }
+ }
+ return false;
+ });
+}
+
bool ASTPrinter::EmitLet(const ast::Let* let) {
auto* sem = builder_.Sem().Get(let);
auto* type = sem->Type()->UnwrapRef();
auto out = Line();
- out << "const ";
+
+ // TODO(crbug.com/tint/2059): Workaround DXC bug with const instances of struct/array-of-matrix.
+ if (!IsStructOrArrayOfMatrix(type)) {
+ out << "const ";
+ }
+
if (!EmitTypeAndName(out, type, core::AddressSpace::kUndefined, core::Access::kUndefined,
let->name->symbol.Name())) {
return false;
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
index 2cf5165..0c449b2 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.h
@@ -595,6 +595,10 @@
return builder_.TypeOf(ptr);
}
+ /// @return true if ty is a struct or array with a matrix member (recursively), false otherwise.
+ /// @param ty the type that will be queried.
+ bool IsStructOrArrayOfMatrix(const core::type::Type* ty);
+
ProgramBuilder builder_;
/// Helper functions emitted at the top of the output
@@ -613,6 +617,7 @@
std::unordered_map<const core::type::Matrix*, std::string> dynamic_matrix_scalar_write_;
std::unordered_map<const core::type::Type*, std::string> value_or_one_if_zero_;
std::unordered_set<const core::type::Struct*> emitted_structs_;
+ std::unordered_map<const core::type::Type*, bool> is_struct_or_array_of_matrix_;
};
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/ast_printer/variable_decl_statement_test.cc b/src/tint/lang/hlsl/writer/ast_printer/variable_decl_statement_test.cc
index b64cbbd..8e52ea8 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/variable_decl_statement_test.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/variable_decl_statement_test.cc
@@ -442,5 +442,148 @@
)");
}
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Mat) {
+ auto* C = Const("C", Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(void f() {
+ const float2x3 l = float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f));
+}
+)");
+}
+
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Struct_of_Mat) {
+ Structure("S", Vector{Member("m", ty.mat2x3<f32>())});
+ auto* C = Const("C", Call("S", Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f)));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(struct S {
+ float2x3 m;
+};
+
+void f() {
+ S l = {float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f))};
+}
+)");
+}
+
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Struct_of_Struct_of_Mat) {
+ Structure("S", Vector{Member("m", ty.mat2x3<f32>())});
+ Structure("S2", Vector{Member("s", ty("S"))});
+ auto* C = Const("C", Call("S2", Call("S", Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f))));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(struct S {
+ float2x3 m;
+};
+struct S2 {
+ S s;
+};
+
+void f() {
+ S2 l = {{float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f))}};
+}
+)");
+}
+
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Struct_of_Array_of_Mat) {
+ Structure("S", Vector{Member("m", ty.array(ty.mat2x3<f32>(), 1_u))});
+
+ auto* C = Const("C", Call("S", Call(ty.array(ty.mat2x3<f32>(), 1_u),
+ Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f))));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(struct S {
+ float2x3 m[1];
+};
+
+void f() {
+ S l = {{float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f))}};
+}
+)");
+}
+
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Array_of_Mat) {
+ auto* C = Const("C", Call(ty.array(ty.mat2x3<f32>(), 1_u),
+ Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f)));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(void f() {
+ float2x3 l[1] = {float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f))};
+}
+)");
+}
+
+TEST_F(HlslASTPrinterTest_VariableDecl, Emit_VariableDeclStatement_Const_Array_of_Struct_of_Mat) {
+ Structure("S", Vector{Member("m", ty.mat2x3<f32>())});
+
+ auto* C = Const("C", Call(ty.array(ty("S"), 1_u),
+ Call(ty("S"), Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f))));
+
+ Func("f", tint::Empty, ty.void_(),
+ Vector{
+ Decl(C),
+ Decl(Let("l", Expr(C))),
+ });
+
+ ASTPrinter& gen = Build();
+
+ ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
+
+ EXPECT_EQ(gen.Result(), R"(struct S {
+ float2x3 m;
+};
+
+void f() {
+ S l[1] = {{float2x3(float3(1.0f, 2.0f, 3.0f), float3(4.0f, 5.0f, 6.0f))}};
+}
+)");
+}
+
} // namespace
} // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/ast_raise/localize_struct_array_assignment.cc b/src/tint/lang/hlsl/writer/ast_raise/localize_struct_array_assignment.cc
index d4042f1..c52041d 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/localize_struct_array_assignment.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/localize_struct_array_assignment.cc
@@ -215,12 +215,8 @@
},
[&](const core::type::Pointer* ptr) {
return std::make_pair(ptr->StoreType(), ptr->AddressSpace());
- },
- [&](Default) {
- TINT_ICE() << "Expecting to find variable of type pointer or reference on lhs "
- "of assignment statement";
- return std::pair<const core::type::Type*, core::AddressSpace>{};
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
};
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 0149f74..69a4f31 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -327,12 +327,8 @@
},
[&](const ast::ConstAssert*) {
return true; // Not emitted
- },
- [&](Default) {
- // These are pushed into the entry point by sanitizer transforms.
- TINT_ICE() << "unhandled type: " << decl->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
}
@@ -641,14 +637,12 @@
auto* call = builder_.Sem().Get<sem::Call>(expr);
auto* target = call->Target();
return Switch(
- target, [&](const sem::Function* func) { return EmitFunctionCall(out, call, func); },
+ target, //
+ [&](const sem::Function* func) { return EmitFunctionCall(out, call, func); },
[&](const sem::BuiltinFn* builtin) { return EmitBuiltinCall(out, call, builtin); },
[&](const sem::ValueConversion* conv) { return EmitTypeConversion(out, call, conv); },
- [&](const sem::ValueConstructor* ctor) { return EmitTypeInitializer(out, call, ctor); },
- [&](Default) {
- TINT_ICE() << "unhandled call target: " << target->TypeInfo().name;
- return false;
- });
+ [&](const sem::ValueConstructor* ctor) { return EmitTypeInitializer(out, call, ctor); }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitFunctionCall(StringStream& out,
@@ -1718,12 +1712,8 @@
[&](const core::type::Struct*) {
out << "{}";
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "Invalid type for zero emission: " + type->FriendlyName());
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitConstant(StringStream& out, const core::constant::Value* constant) {
@@ -1844,12 +1834,8 @@
}
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unhandled constant type: " + constant->Type()->FriendlyName());
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitLiteral(StringStream& out, const ast::LiteralExpression* lit) {
@@ -1881,11 +1867,8 @@
}
diagnostics_.add_error(diag::System::Writer, "unknown integer literal suffix type");
return false;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "unknown literal type");
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitExpression(StringStream& out, const ast::Expression* expr) {
@@ -1903,12 +1886,8 @@
[&](const ast::IdentifierExpression* i) { return EmitIdentifier(out, i); },
[&](const ast::LiteralExpression* l) { return EmitLiteral(out, l); },
[&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(out, m); },
- [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); },
- [&](Default) { //
- diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
- std::string(expr->TypeInfo().name));
- return false;
- });
+ [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitStage(StringStream& out, ast::PipelineStage stage) {
@@ -2404,11 +2383,8 @@
}
out << "." << member_access->Member()->Name().Name();
return true;
- },
- [&](Default) {
- TINT_ICE() << "unknown member access type: " << sem->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitReturn(const ast::ReturnStatement* stmt) {
@@ -2490,20 +2466,13 @@
[&](const ast::Let* let) { return EmitLet(let); },
[&](const ast::Const*) {
return true; // Constants are embedded at their use
- },
- [&](Default) { //
- TINT_ICE() << "unknown statement type: " << stmt->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](const ast::ConstAssert*) {
return true; // Not emitted
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitStatements(VectorRef<const ast::Statement*> stmts) {
@@ -2711,11 +2680,8 @@
}
out << ", access::sample";
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer, "invalid texture type");
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
},
[&](const core::type::U32*) {
out << "uint";
@@ -2734,12 +2700,8 @@
[&](const core::type::Void*) {
out << "void";
return true;
- },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown type in EmitType: " + type->FriendlyName());
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool ASTPrinter::EmitTypeAndName(StringStream& out,
diff --git a/src/tint/lang/msl/writer/common/printer_support.cc b/src/tint/lang/msl/writer/common/printer_support.cc
index f11a27e..d9e74c4 100644
--- a/src/tint/lang/msl/writer/common/printer_support.cc
+++ b/src/tint/lang/msl/writer/common/printer_support.cc
@@ -224,10 +224,7 @@
[&](const core::type::Atomic* atomic) { return MslPackedTypeSizeAndAlign(atomic->Type()); },
- [&](Default) {
- TINT_UNREACHABLE() << "Unhandled type " << ty->TypeInfo().name;
- return SizeAndAlign{};
- });
+ TINT_ICE_ON_NO_MATCH);
}
void PrintF32(StringStream& out, float value) {
diff --git a/src/tint/lang/msl/writer/printer/BUILD.bazel b/src/tint/lang/msl/writer/printer/BUILD.bazel
index 806a71d..c72bc54 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.bazel
+++ b/src/tint/lang/msl/writer/printer/BUILD.bazel
@@ -97,7 +97,6 @@
"//src/tint/lang/msl/writer/raise",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
- "//src/tint/utils/generator",
"//src/tint/utils/ice",
"//src/tint/utils/id",
"//src/tint/utils/macros",
diff --git a/src/tint/lang/msl/writer/printer/BUILD.cmake b/src/tint/lang/msl/writer/printer/BUILD.cmake
index 3141c6e..8992d6c 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.cmake
+++ b/src/tint/lang/msl/writer/printer/BUILD.cmake
@@ -102,7 +102,6 @@
tint_lang_msl_writer_raise
tint_utils_containers
tint_utils_diagnostic
- tint_utils_generator
tint_utils_ice
tint_utils_id
tint_utils_macros
diff --git a/src/tint/lang/msl/writer/printer/BUILD.gn b/src/tint/lang/msl/writer/printer/BUILD.gn
index e2d37b2..bfd0975 100644
--- a/src/tint/lang/msl/writer/printer/BUILD.gn
+++ b/src/tint/lang/msl/writer/printer/BUILD.gn
@@ -99,7 +99,6 @@
"${tint_src_dir}/lang/msl/writer/raise",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
- "${tint_src_dir}/utils/generator",
"${tint_src_dir}/utils/ice",
"${tint_src_dir}/utils/id",
"${tint_src_dir}/utils/macros",
diff --git a/src/tint/lang/msl/writer/printer/helper_test.h b/src/tint/lang/msl/writer/printer/helper_test.h
index e6fe363..e0ecbf8 100644
--- a/src/tint/lang/msl/writer/printer/helper_test.h
+++ b/src/tint/lang/msl/writer/printer/helper_test.h
@@ -61,8 +61,6 @@
template <typename BASE>
class MslPrinterTestHelperBase : public BASE {
public:
- MslPrinterTestHelperBase() : writer_(mod) {}
-
/// The test module.
core::ir::Module mod;
/// The test builder.
@@ -71,9 +69,6 @@
core::type::Manager& ty{mod.Types()};
protected:
- /// The MSL writer.
- Printer writer_;
-
/// Validation errors
std::string err_;
@@ -83,18 +78,17 @@
/// Run the writer on the IR module and validate the result.
/// @returns true if generation and validation succeeded
bool Generate() {
- auto raised = raise::Raise(mod);
- if (!raised) {
+ if (auto raised = raise::Raise(mod); !raised) {
err_ = raised.Failure().reason.str();
return false;
}
- auto result = writer_.Generate();
+ auto result = Print(mod);
if (!result) {
err_ = result.Failure().reason.str();
return false;
}
- output_ = writer_.Result();
+ output_ = result.Get();
return true;
}
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 45ea4d1..02fb9c5 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -27,6 +27,8 @@
#include "src/tint/lang/msl/writer/printer/printer.h"
+#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include "src/tint/lang/core/constant/composite.h"
@@ -38,6 +40,7 @@
#include "src/tint/lang/core/ir/if.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
+#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/unreachable.h"
@@ -63,6 +66,7 @@
#include "src/tint/lang/core/type/void.h"
#include "src/tint/lang/msl/writer/common/printer_support.h"
#include "src/tint/utils/containers/map.h"
+#include "src/tint/utils/generator/text_generator.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/text/string.h"
@@ -70,836 +74,979 @@
using namespace tint::core::fluent_types; // NOLINT
namespace tint::msl::writer {
+namespace {
-// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
-#define UNHANDLED_CASE(object_ptr) \
- TINT_UNIMPLEMENTED() << "unhandled case in Switch(): " \
- << (object_ptr ? object_ptr->TypeInfo().name : "<null>")
+/// PIMPL class for the MSL generator
+class Printer : public tint::TextGenerator {
+ public:
+ /// Constructor
+ /// @param module the Tint IR module to generate
+ explicit Printer(core::ir::Module& module) : ir_(module) {}
-Printer::Printer(core::ir::Module& module) : ir_(module) {}
+ /// @returns the generated MSL shader
+ tint::Result<std::string> Generate() {
+ auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "MSL writer");
+ if (!valid) {
+ return std::move(valid.Failure());
+ }
-Printer::~Printer() = default;
+ {
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
+ Line() << "#include <metal_stdlib>";
+ Line() << "using namespace metal;";
+ }
-tint::Result<SuccessType> Printer::Generate() {
- auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "MSL writer");
- if (!valid) {
- return std::move(valid.Failure());
+ // Emit module-scope declarations.
+ EmitBlockInstructions(ir_.root_block);
+
+ // Emit functions.
+ for (auto* func : ir_.functions) {
+ EmitFunction(func);
+ }
+
+ StringStream ss;
+ ss << preamble_buffer_.String() << std::endl << main_buffer_.String();
+ return ss.str();
}
- {
+ private:
+ /// Map of builtin structure to unique generated name
+ std::unordered_map<const core::type::Struct*, std::string> builtin_struct_names_;
+
+ core::ir::Module& ir_;
+
+ /// The buffer holding preamble text
+ TextBuffer preamble_buffer_;
+
+ /// Unique name of the 'TINT_INVARIANT' preprocessor define.
+ /// Non-empty only if an invariant attribute has been generated.
+ std::string invariant_define_name_;
+
+ std::unordered_set<const core::type::Struct*> emitted_structs_;
+
+ /// The current function being emitted
+ core::ir::Function* current_function_ = nullptr;
+ /// The current block being emitted
+ core::ir::Block* current_block_ = nullptr;
+
+ /// Unique name of the tint_array<T, N> template.
+ /// Non-empty only if the template has been generated.
+ std::string array_template_name_;
+
+ /// The representation for an IR pointer type
+ enum class PtrKind {
+ kPtr, // IR pointer is represented in a pointer
+ kRef, // IR pointer is represented in a reference
+ };
+
+ /// The structure for a value held by a 'let', 'var' or parameter.
+ struct VariableValue {
+ Symbol name; // Name of the variable
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
+
+ /// The structure for an inlined value
+ struct InlinedValue {
+ std::string expr;
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
+
+ /// Empty struct used as a sentinel value to indicate that an string expression has been
+ /// consumed by its single place of usage. Attempting to use this value a second time should
+ /// result in an ICE.
+ struct ConsumedValue {};
+
+ using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
+
+ /// IR values to their representation
+ Hashmap<core::ir::Value*, ValueBinding, 32> bindings_;
+
+ /// Values that can be inlined.
+ Hashset<core::ir::Value*, 64> can_inline_;
+
+ /// @returns the name of the templated `tint_array` helper type, generating it if needed
+ const std::string& ArrayTemplateName() {
+ if (!array_template_name_.empty()) {
+ return array_template_name_;
+ }
+
+ array_template_name_ = UniqueIdentifier("tint_array");
+
TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
- Line() << "#include <metal_stdlib>";
- Line() << "using namespace metal;";
- }
+ Line() << "template<typename T, size_t N>";
+ Line() << "struct " << array_template_name_ << " {";
- // Emit module-scope declarations.
- EmitBlockInstructions(ir_.root_block);
+ {
+ ScopedIndent si(current_buffer_);
+ Line()
+ << "const constant T& operator[](size_t i) const constant { return elements[i]; }";
+ for (auto* space : {"device", "thread", "threadgroup"}) {
+ Line() << space << " T& operator[](size_t i) " << space
+ << " { return elements[i]; }";
+ Line() << "const " << space << " T& operator[](size_t i) const " << space
+ << " { return elements[i]; }";
+ }
+ Line() << "T elements[N];";
+ }
+ Line() << "};";
+ Line();
- // Emit functions.
- for (auto* func : ir_.functions) {
- EmitFunction(func);
- }
-
- return Success;
-}
-
-std::string Printer::Result() const {
- StringStream ss;
- ss << preamble_buffer_.String() << std::endl << main_buffer_.String();
- return ss.str();
-}
-
-const std::string& Printer::ArrayTemplateName() {
- if (!array_template_name_.empty()) {
return array_template_name_;
}
- array_template_name_ = UniqueIdentifier("tint_array");
+ /// Emit the function
+ /// @param func the function to emit
+ void EmitFunction(core::ir::Function* func) {
+ TINT_SCOPED_ASSIGNMENT(current_function_, func);
- TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
- Line() << "template<typename T, size_t N>";
- Line() << "struct " << array_template_name_ << " {";
+ {
+ auto out = Line();
- {
- ScopedIndent si(current_buffer_);
- Line() << "const constant T& operator[](size_t i) const constant { return elements[i]; }";
- for (auto* space : {"device", "thread", "threadgroup"}) {
- Line() << space << " T& operator[](size_t i) " << space << " { return elements[i]; }";
- Line() << "const " << space << " T& operator[](size_t i) const " << space
- << " { return elements[i]; }";
+ // TODO(dsinclair): Emit function stage if any
+ // TODO(dsinclair): Handle return type attributes
+
+ EmitType(out, func->ReturnType());
+ out << " " << ir_.NameOf(func).Name() << "() {";
+
+ // TODO(dsinclair): Emit Function parameters
}
- Line() << "T elements[N];";
+ {
+ ScopedIndent si(current_buffer_);
+ EmitBlock(func->Block());
+ }
+
+ Line() << "}";
}
- Line() << "};";
- Line();
- return array_template_name_;
-}
+ /// Emit a block
+ /// @param block the block to emit
+ void EmitBlock(core::ir::Block* block) {
+ MarkInlinable(block);
-void Printer::EmitFunction(core::ir::Function* func) {
- TINT_SCOPED_ASSIGNMENT(current_function_, func);
+ EmitBlockInstructions(block);
+ }
- {
+ /// Emit the instructions in a block
+ /// @param block the block with the instructions to emit
+ void EmitBlockInstructions(core::ir::Block* block) {
+ TINT_SCOPED_ASSIGNMENT(current_block_, block);
+
+ for (auto* inst : *block) {
+ Switch(
+ inst, //
+ [&](core::ir::Binary* b) { EmitBinary(b); }, //
+ [&](core::ir::ExitIf* e) { EmitExitIf(e); }, //
+ [&](core::ir::If* if_) { EmitIf(if_); }, //
+ [&](core::ir::Let* l) { EmitLet(l); }, //
+ [&](core::ir::Load* l) { EmitLoad(l); }, //
+ [&](core::ir::Return* r) { EmitReturn(r); }, //
+ [&](core::ir::Unreachable*) { EmitUnreachable(); }, //
+ [&](core::ir::Var* v) { EmitVar(v); }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
+ }
+
+ /// Emit a binary instruction
+ /// @param b the binary instruction
+ void EmitBinary(core::ir::Binary* b) {
+ if (b->Op() == core::ir::BinaryOp::kEqual) {
+ auto* rhs = b->RHS()->As<core::ir::Constant>();
+ if (rhs && rhs->Type()->Is<core::type::Bool>() &&
+ rhs->Value()->ValueAs<bool>() == false) {
+ // expr == false
+ Bind(b->Result(), "!(" + Expr(b->LHS()) + ")");
+ return;
+ }
+ }
+
+ auto kind = [&] {
+ switch (b->Op()) {
+ case core::ir::BinaryOp::kAdd:
+ return "+";
+ case core::ir::BinaryOp::kSubtract:
+ return "-";
+ case core::ir::BinaryOp::kMultiply:
+ return "*";
+ case core::ir::BinaryOp::kDivide:
+ return "/";
+ case core::ir::BinaryOp::kModulo:
+ return "%";
+ case core::ir::BinaryOp::kAnd:
+ return "&";
+ case core::ir::BinaryOp::kOr:
+ return "|";
+ case core::ir::BinaryOp::kXor:
+ return "^";
+ case core::ir::BinaryOp::kEqual:
+ return "==";
+ case core::ir::BinaryOp::kNotEqual:
+ return "!=";
+ case core::ir::BinaryOp::kLessThan:
+ return "<";
+ case core::ir::BinaryOp::kGreaterThan:
+ return ">";
+ case core::ir::BinaryOp::kLessThanEqual:
+ return "<=";
+ case core::ir::BinaryOp::kGreaterThanEqual:
+ return ">=";
+ case core::ir::BinaryOp::kShiftLeft:
+ return "<<";
+ case core::ir::BinaryOp::kShiftRight:
+ return ">>";
+ }
+ return "<error>";
+ };
+
+ StringStream str;
+ str << "(" << Expr(b->LHS()) << " " << kind() << " " + Expr(b->RHS()) << ")";
+
+ Bind(b->Result(), str.str());
+ }
+
+ /// Emit a load instruction
+ /// @param l the load instruction
+ void EmitLoad(core::ir::Load* l) {
+ // Force loads to be bound as inlines
+ bindings_.Add(l->Result(), InlinedValue{Expr(l->From()), PtrKind::kRef});
+ }
+
+ /// Emit a var instruction
+ /// @param v the var instruction
+ void EmitVar(core::ir::Var* v) {
auto out = Line();
- // TODO(dsinclair): Emit function stage if any
- // TODO(dsinclair): Handle return type attributes
+ auto* ptr = v->Result()->Type()->As<core::type::Pointer>();
+ TINT_ASSERT_OR_RETURN(ptr);
- EmitType(out, func->ReturnType());
- out << " " << ir_.NameOf(func).Name() << "() {";
-
- // TODO(dsinclair): Emit Function parameters
- }
- {
- ScopedIndent si(current_buffer_);
- EmitBlock(func->Block());
- }
-
- Line() << "}";
-}
-
-void Printer::EmitBlock(core::ir::Block* block) {
- MarkInlinable(block);
-
- EmitBlockInstructions(block);
-}
-
-void Printer::EmitBlockInstructions(core::ir::Block* block) {
- TINT_SCOPED_ASSIGNMENT(current_block_, block);
-
- for (auto* inst : *block) {
- Switch(
- inst, //
- [&](core::ir::Binary* b) { EmitBinary(b); }, //
- [&](core::ir::ExitIf* e) { EmitExitIf(e); }, //
- [&](core::ir::If* if_) { EmitIf(if_); }, //
- [&](core::ir::Let* l) { EmitLet(l); }, //
- [&](core::ir::Load* l) { EmitLoad(l); }, //
- [&](core::ir::Return* r) { EmitReturn(r); }, //
- [&](core::ir::Unreachable*) { EmitUnreachable(); }, //
- [&](core::ir::Var* v) { EmitVar(v); }, //
- [&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
- }
-}
-
-void Printer::EmitBinary(core::ir::Binary* b) {
- if (b->Op() == core::ir::BinaryOp::kEqual) {
- auto* rhs = b->RHS()->As<core::ir::Constant>();
- if (rhs && rhs->Type()->Is<core::type::Bool>() && rhs->Value()->ValueAs<bool>() == false) {
- // expr == false
- Bind(b->Result(), "!(" + Expr(b->LHS()) + ")");
- return;
- }
- }
-
- auto kind = [&] {
- switch (b->Op()) {
- case core::ir::BinaryOp::kAdd:
- return "+";
- case core::ir::BinaryOp::kSubtract:
- return "-";
- case core::ir::BinaryOp::kMultiply:
- return "*";
- case core::ir::BinaryOp::kDivide:
- return "/";
- case core::ir::BinaryOp::kModulo:
- return "%";
- case core::ir::BinaryOp::kAnd:
- return "&";
- case core::ir::BinaryOp::kOr:
- return "|";
- case core::ir::BinaryOp::kXor:
- return "^";
- case core::ir::BinaryOp::kEqual:
- return "==";
- case core::ir::BinaryOp::kNotEqual:
- return "!=";
- case core::ir::BinaryOp::kLessThan:
- return "<";
- case core::ir::BinaryOp::kGreaterThan:
- return ">";
- case core::ir::BinaryOp::kLessThanEqual:
- return "<=";
- case core::ir::BinaryOp::kGreaterThanEqual:
- return ">=";
- case core::ir::BinaryOp::kShiftLeft:
- return "<<";
- case core::ir::BinaryOp::kShiftRight:
- return ">>";
- }
- return "<error>";
- };
-
- StringStream str;
- str << "(" << Expr(b->LHS()) << " " << kind() << " " + Expr(b->RHS()) << ")";
-
- Bind(b->Result(), str.str());
-}
-
-void Printer::EmitLoad(core::ir::Load* l) {
- // Force loads to be bound as inlines
- bindings_.Add(l->Result(), InlinedValue{Expr(l->From()), PtrKind::kRef});
-}
-
-void Printer::EmitVar(core::ir::Var* v) {
- auto out = Line();
-
- auto* ptr = v->Result()->Type()->As<core::type::Pointer>();
- TINT_ASSERT_OR_RETURN(ptr);
-
- auto space = ptr->AddressSpace();
- switch (space) {
- case core::AddressSpace::kFunction:
- case core::AddressSpace::kHandle:
- break;
- case core::AddressSpace::kPrivate:
- out << "thread ";
- break;
- case core::AddressSpace::kWorkgroup:
- out << "threadgroup ";
- break;
- default:
- TINT_ICE() << "unhandled variable address space";
- return;
- }
-
- auto name = ir_.NameOf(v);
-
- EmitType(out, ptr->UnwrapPtr());
- out << " " << name.Name();
-
- if (v->Initializer()) {
- out << " = " << Expr(v->Initializer());
- } else if (space == core::AddressSpace::kPrivate || space == core::AddressSpace::kFunction ||
- space == core::AddressSpace::kUndefined) {
- out << " = ";
- EmitZeroValue(out, ptr->UnwrapPtr());
- }
- out << ";";
-
- Bind(v->Result(), name, PtrKind::kRef);
-}
-
-void Printer::EmitLet(core::ir::Let* l) {
- Bind(l->Result(), Expr(l->Value(), PtrKind::kPtr), PtrKind::kPtr);
-}
-
-void Printer::EmitIf(core::ir::If* if_) {
- // Emit any nodes that need to be used as PHI nodes
- for (auto* phi : if_->Results()) {
- if (!ir_.NameOf(phi).IsValid()) {
- ir_.SetName(phi, ir_.symbols.New());
- }
-
- auto name = ir_.NameOf(phi);
-
- auto out = Line();
- EmitType(out, phi->Type());
- out << " " << name.Name() << ";";
-
- Bind(phi, name);
- }
-
- Line() << "if (" << Expr(if_->Condition()) << ") {";
-
- {
- ScopedIndent si(current_buffer_);
- EmitBlockInstructions(if_->True());
- }
-
- if (if_->False() && !if_->False()->IsEmpty()) {
- Line() << "} else {";
-
- ScopedIndent si(current_buffer_);
- EmitBlockInstructions(if_->False());
- }
-
- Line() << "}";
-}
-
-void Printer::EmitExitIf(core::ir::ExitIf* e) {
- auto results = e->If()->Results();
- auto args = e->Args();
- for (size_t i = 0; i < e->Args().Length(); ++i) {
- auto* phi = results[i];
- auto* val = args[i];
-
- Line() << ir_.NameOf(phi).Name() << " = " << Expr(val) << ";";
- }
-}
-
-void Printer::EmitReturn(core::ir::Return* r) {
- // If this return has no arguments and the current block is for the function which is
- // being returned, skip the return.
- if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
- return;
- }
-
- auto out = Line();
- out << "return";
- if (!r->Args().IsEmpty()) {
- out << " " << Expr(r->Args().Front());
- }
- out << ";";
-}
-
-void Printer::EmitUnreachable() {
- Line() << "/* unreachable */";
-}
-
-void Printer::EmitAddressSpace(StringStream& out, core::AddressSpace sc) {
- switch (sc) {
- case core::AddressSpace::kFunction:
- case core::AddressSpace::kPrivate:
- case core::AddressSpace::kHandle:
- out << "thread";
- break;
- case core::AddressSpace::kWorkgroup:
- out << "threadgroup";
- break;
- case core::AddressSpace::kStorage:
- out << "device";
- break;
- case core::AddressSpace::kUniform:
- out << "constant";
- break;
- default:
- TINT_ICE() << "unhandled address space: " << sc;
- break;
- }
-}
-
-void Printer::EmitType(StringStream& out, const core::type::Type* ty) {
- tint::Switch(
- ty, //
- [&](const core::type::Bool*) { out << "bool"; }, //
- [&](const core::type::Void*) { out << "void"; }, //
- [&](const core::type::F32*) { out << "float"; }, //
- [&](const core::type::F16*) { out << "half"; }, //
- [&](const core::type::I32*) { out << "int"; }, //
- [&](const core::type::U32*) { out << "uint"; }, //
- [&](const core::type::Array* arr) { EmitArrayType(out, arr); },
- [&](const core::type::Vector* vec) { EmitVectorType(out, vec); },
- [&](const core::type::Matrix* mat) { EmitMatrixType(out, mat); },
- [&](const core::type::Atomic* atomic) { EmitAtomicType(out, atomic); },
- [&](const core::type::Pointer* ptr) { EmitPointerType(out, ptr); },
- [&](const core::type::Sampler*) { out << "sampler"; }, //
- [&](const core::type::Texture* tex) { EmitTextureType(out, tex); },
- [&](const core::type::Struct* str) {
- out << StructName(str);
-
- TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
- EmitStructType(str);
- },
- [&](Default) { UNHANDLED_CASE(ty); });
-}
-
-void Printer::EmitPointerType(StringStream& out, const core::type::Pointer* ptr) {
- if (ptr->Access() == core::Access::kRead) {
- out << "const ";
- }
- EmitAddressSpace(out, ptr->AddressSpace());
- out << " ";
- EmitType(out, ptr->StoreType());
- out << "*";
-}
-
-void Printer::EmitAtomicType(StringStream& out, const core::type::Atomic* atomic) {
- if (atomic->Type()->Is<core::type::I32>()) {
- out << "atomic_int";
- return;
- }
- if (TINT_LIKELY(atomic->Type()->Is<core::type::U32>())) {
- out << "atomic_uint";
- return;
- }
- TINT_ICE() << "unhandled atomic type " << atomic->Type()->FriendlyName();
-}
-
-void Printer::EmitArrayType(StringStream& out, const core::type::Array* arr) {
- out << ArrayTemplateName() << "<";
- EmitType(out, arr->ElemType());
- out << ", ";
- if (arr->Count()->Is<core::type::RuntimeArrayCount>()) {
- out << "1";
- } else {
- auto count = arr->ConstantCount();
- if (!count) {
- TINT_ICE() << core::type::Array::kErrExpectedConstantCount;
- return;
- }
- out << count.value();
- }
- out << ">";
-}
-
-void Printer::EmitVectorType(StringStream& out, const core::type::Vector* vec) {
- if (vec->Packed()) {
- out << "packed_";
- }
- EmitType(out, vec->type());
- out << vec->Width();
-}
-
-void Printer::EmitMatrixType(StringStream& out, const core::type::Matrix* mat) {
- EmitType(out, mat->type());
- out << mat->columns() << "x" << mat->rows();
-}
-
-void Printer::EmitTextureType(StringStream& out, const core::type::Texture* tex) {
- if (TINT_UNLIKELY(tex->Is<core::type::ExternalTexture>())) {
- TINT_ICE() << "Multiplanar external texture transform was not run.";
- return;
- }
-
- if (tex->IsAnyOf<core::type::DepthTexture, core::type::DepthMultisampledTexture>()) {
- out << "depth";
- } else {
- out << "texture";
- }
-
- switch (tex->dim()) {
- case core::type::TextureDimension::k1d:
- out << "1d";
- break;
- case core::type::TextureDimension::k2d:
- out << "2d";
- break;
- case core::type::TextureDimension::k2dArray:
- out << "2d_array";
- break;
- case core::type::TextureDimension::k3d:
- out << "3d";
- break;
- case core::type::TextureDimension::kCube:
- out << "cube";
- break;
- case core::type::TextureDimension::kCubeArray:
- out << "cube_array";
- break;
- default:
- TINT_ICE() << "invalid texture dimensions";
- return;
- }
- if (tex->IsAnyOf<core::type::MultisampledTexture, core::type::DepthMultisampledTexture>()) {
- out << "_ms";
- }
- out << "<";
- TINT_DEFER(out << ">");
-
- tint::Switch(
- tex, //
- [&](const core::type::DepthTexture*) { out << "float, access::sample"; },
- [&](const core::type::DepthMultisampledTexture*) { out << "float, access::read"; },
- [&](const core::type::StorageTexture* storage) {
- EmitType(out, storage->type());
- out << ", ";
-
- std::string access_str;
- if (storage->access() == core::Access::kRead) {
- out << "access::read";
- } else if (storage->access() == core::Access::kWrite) {
- out << "access::write";
- } else {
- TINT_ICE() << "invalid access control for storage texture";
+ auto space = ptr->AddressSpace();
+ switch (space) {
+ case core::AddressSpace::kFunction:
+ case core::AddressSpace::kHandle:
+ break;
+ case core::AddressSpace::kPrivate:
+ out << "thread ";
+ break;
+ case core::AddressSpace::kWorkgroup:
+ out << "threadgroup ";
+ break;
+ default:
+ TINT_ICE() << "unhandled variable address space";
return;
- }
- },
- [&](const core::type::MultisampledTexture* ms) {
- EmitType(out, ms->type());
- out << ", access::read";
- },
- [&](const core::type::SampledTexture* sampled) {
- EmitType(out, sampled->type());
- out << ", access::sample";
- },
- [&](Default) { TINT_ICE() << "invalid texture type"; });
-}
-
-void Printer::EmitStructType(const core::type::Struct* str) {
- auto it = emitted_structs_.emplace(str);
- if (!it.second) {
- return;
- }
-
- // This does not append directly to the preamble because a struct may require other
- // structs, or the array template, to get emitted before it. So, the struct emits into a
- // temporary text buffer, then anything it depends on will emit to the preamble first,
- // and then it copies the text buffer into the preamble.
- TextBuffer str_buf;
- Line(&str_buf) << "struct " << StructName(str) << " {";
-
- bool is_host_shareable = str->IsHostShareable();
-
- // Emits a `/* 0xnnnn */` byte offset comment for a struct member.
- auto add_byte_offset_comment = [&](StringStream& out, uint32_t offset) {
- std::ios_base::fmtflags saved_flag_state(out.flags());
- out << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset << " */ ";
- out.flags(saved_flag_state);
- };
-
- auto add_padding = [&](uint32_t size, uint32_t msl_offset) {
- std::string name;
- do {
- name = UniqueIdentifier("tint_pad");
- } while (str->FindMember(ir_.symbols.Get(name)));
-
- auto out = Line(&str_buf);
- add_byte_offset_comment(out, msl_offset);
- out << ArrayTemplateName() << "<int8_t, " << size << "> " << name << ";";
- };
-
- str_buf.IncrementIndent();
-
- uint32_t msl_offset = 0;
- for (auto* mem : str->Members()) {
- auto out = Line(&str_buf);
- auto mem_name = mem->Name().Name();
- auto ir_offset = mem->Offset();
-
- if (is_host_shareable) {
- if (TINT_UNLIKELY(ir_offset < msl_offset)) {
- // Unimplementable layout
- TINT_ICE() << "Structure member offset (" << ir_offset << ") is behind MSL offset ("
- << msl_offset << ")";
- return;
- }
-
- // Generate padding if required
- if (auto padding = ir_offset - msl_offset) {
- add_padding(padding, msl_offset);
- msl_offset += padding;
- }
-
- add_byte_offset_comment(out, msl_offset);
}
- auto* ty = mem->Type();
+ auto name = ir_.NameOf(v);
- EmitType(out, ty);
- out << " " << mem_name;
+ EmitType(out, ptr->UnwrapPtr());
+ out << " " << name.Name();
- // Emit attributes
- auto& attributes = mem->Attributes();
-
- if (auto builtin = attributes.builtin) {
- auto name = BuiltinToAttribute(builtin.value());
- if (name.empty()) {
- TINT_ICE() << "unknown builtin";
- return;
- }
- out << " [[" << name << "]]";
+ if (v->Initializer()) {
+ out << " = " << Expr(v->Initializer());
+ } else if (space == core::AddressSpace::kPrivate ||
+ space == core::AddressSpace::kFunction ||
+ space == core::AddressSpace::kUndefined) {
+ out << " = ";
+ EmitZeroValue(out, ptr->UnwrapPtr());
}
-
- if (auto location = attributes.location) {
- auto& pipeline_stage_uses = str->PipelineStageUses();
- if (TINT_UNLIKELY(pipeline_stage_uses.size() != 1)) {
- TINT_ICE() << "invalid entry point IO struct uses";
- return;
- }
-
- if (pipeline_stage_uses.count(core::type::PipelineStageUsage::kVertexInput)) {
- out << " [[attribute(" + std::to_string(location.value()) + ")]]";
- } else if (pipeline_stage_uses.count(core::type::PipelineStageUsage::kVertexOutput)) {
- out << " [[user(locn" + std::to_string(location.value()) + ")]]";
- } else if (pipeline_stage_uses.count(core::type::PipelineStageUsage::kFragmentInput)) {
- out << " [[user(locn" + std::to_string(location.value()) + ")]]";
- } else if (TINT_LIKELY(pipeline_stage_uses.count(
- core::type::PipelineStageUsage::kFragmentOutput))) {
- out << " [[color(" + std::to_string(location.value()) + ")]]";
- } else {
- TINT_ICE() << "invalid use of location decoration";
- return;
- }
- }
-
- if (auto interpolation = attributes.interpolation) {
- auto name = InterpolationToAttribute(interpolation->type, interpolation->sampling);
- if (name.empty()) {
- TINT_ICE() << "unknown interpolation attribute";
- return;
- }
- out << " [[" << name << "]]";
- }
-
- if (attributes.invariant) {
- invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
- out << " " << invariant_define_name_;
- }
-
out << ";";
- if (is_host_shareable) {
- // Calculate new MSL offset
- auto size_align = MslPackedTypeSizeAndAlign(ty);
- if (TINT_UNLIKELY(msl_offset % size_align.align)) {
- TINT_ICE() << "Misaligned MSL structure member " << mem_name << " : "
- << ty->FriendlyName() << " offset: " << msl_offset
- << " align: " << size_align.align;
- return;
+ Bind(v->Result(), name, PtrKind::kRef);
+ }
+
+ /// Emit a let instruction
+ /// @param l the let instruction
+ void EmitLet(core::ir::Let* l) {
+ Bind(l->Result(), Expr(l->Value(), PtrKind::kPtr), PtrKind::kPtr);
+ }
+
+ /// Emit an if instruction
+ /// @param if_ the if instruction
+ void EmitIf(core::ir::If* if_) {
+ // Emit any nodes that need to be used as PHI nodes
+ for (auto* phi : if_->Results()) {
+ if (!ir_.NameOf(phi).IsValid()) {
+ ir_.SetName(phi, ir_.symbols.New());
}
- msl_offset += size_align.size;
+
+ auto name = ir_.NameOf(phi);
+
+ auto out = Line();
+ EmitType(out, phi->Type());
+ out << " " << name.Name() << ";";
+
+ Bind(phi, name);
+ }
+
+ Line() << "if (" << Expr(if_->Condition()) << ") {";
+
+ {
+ ScopedIndent si(current_buffer_);
+ EmitBlockInstructions(if_->True());
+ }
+
+ if (if_->False() && !if_->False()->IsEmpty()) {
+ Line() << "} else {";
+
+ ScopedIndent si(current_buffer_);
+ EmitBlockInstructions(if_->False());
+ }
+
+ Line() << "}";
+ }
+
+ /// Emit an exit-if instruction
+ /// @param e the exit-if instruction
+ void EmitExitIf(core::ir::ExitIf* e) {
+ auto results = e->If()->Results();
+ auto args = e->Args();
+ for (size_t i = 0; i < e->Args().Length(); ++i) {
+ auto* phi = results[i];
+ auto* val = args[i];
+
+ Line() << ir_.NameOf(phi).Name() << " = " << Expr(val) << ";";
}
}
- if (is_host_shareable && str->Size() != msl_offset) {
- add_padding(str->Size() - msl_offset, msl_offset);
+ /// Emit a return instruction
+ /// @param r the return instruction
+ void EmitReturn(core::ir::Return* r) {
+ // If this return has no arguments and the current block is for the function which is
+ // being returned, skip the return.
+ if (current_block_ == current_function_->Block() && r->Args().IsEmpty()) {
+ return;
+ }
+
+ auto out = Line();
+ out << "return";
+ if (!r->Args().IsEmpty()) {
+ out << " " << Expr(r->Args().Front());
+ }
+ out << ";";
}
- str_buf.DecrementIndent();
- Line(&str_buf) << "};";
+ /// Emit an unreachable instruction
+ void EmitUnreachable() { Line() << "/* unreachable */"; }
- preamble_buffer_.Append(str_buf);
-}
-
-void Printer::EmitConstant(StringStream& out, core::ir::Constant* c) {
- EmitConstant(out, c->Value());
-}
-
-void Printer::EmitConstant(StringStream& out, const core::constant::Value* c) {
- auto emit_values = [&](uint32_t count) {
- for (size_t i = 0; i < count; i++) {
- if (i > 0) {
- out << ", ";
- }
- EmitConstant(out, c->Index(i));
+ /// Handles generating a address space
+ /// @param out the output of the type stream
+ /// @param sc the address space to generate
+ void EmitAddressSpace(StringStream& out, core::AddressSpace sc) {
+ switch (sc) {
+ case core::AddressSpace::kFunction:
+ case core::AddressSpace::kPrivate:
+ case core::AddressSpace::kHandle:
+ out << "thread";
+ break;
+ case core::AddressSpace::kWorkgroup:
+ out << "threadgroup";
+ break;
+ case core::AddressSpace::kStorage:
+ out << "device";
+ break;
+ case core::AddressSpace::kUniform:
+ out << "constant";
+ break;
+ default:
+ TINT_ICE() << "unhandled address space: " << sc;
+ break;
}
- };
+ }
- tint::Switch(
- c->Type(), //
- [&](const core::type::Bool*) { out << (c->ValueAs<bool>() ? "true" : "false"); },
- [&](const core::type::I32*) { PrintI32(out, c->ValueAs<i32>()); },
- [&](const core::type::U32*) { out << c->ValueAs<u32>() << "u"; },
- [&](const core::type::F32*) { PrintF32(out, c->ValueAs<f32>()); },
- [&](const core::type::F16*) { PrintF16(out, c->ValueAs<f16>()); },
- [&](const core::type::Vector* v) {
- EmitType(out, v);
+ /// Emit a type
+ /// @param out the stream to emit too
+ /// @param ty the type to emit
+ void EmitType(StringStream& out, const core::type::Type* ty) {
+ tint::Switch(
+ ty, //
+ [&](const core::type::Bool*) { out << "bool"; }, //
+ [&](const core::type::Void*) { out << "void"; }, //
+ [&](const core::type::F32*) { out << "float"; }, //
+ [&](const core::type::F16*) { out << "half"; }, //
+ [&](const core::type::I32*) { out << "int"; }, //
+ [&](const core::type::U32*) { out << "uint"; }, //
+ [&](const core::type::Array* arr) { EmitArrayType(out, arr); },
+ [&](const core::type::Vector* vec) { EmitVectorType(out, vec); },
+ [&](const core::type::Matrix* mat) { EmitMatrixType(out, mat); },
+ [&](const core::type::Atomic* atomic) { EmitAtomicType(out, atomic); },
+ [&](const core::type::Pointer* ptr) { EmitPointerType(out, ptr); },
+ [&](const core::type::Sampler*) { out << "sampler"; }, //
+ [&](const core::type::Texture* tex) { EmitTextureType(out, tex); },
+ [&](const core::type::Struct* str) {
+ out << StructName(str);
- ScopedParen sp(out);
- if (auto* splat = c->As<core::constant::Splat>()) {
- EmitConstant(out, splat->el);
- return;
- }
- emit_values(v->Width());
- },
- [&](const core::type::Matrix* m) {
- EmitType(out, m);
- ScopedParen sp(out);
- emit_values(m->columns());
- },
- [&](const core::type::Array* a) {
- EmitType(out, a);
- out << "{";
- TINT_DEFER(out << "}");
+ TINT_SCOPED_ASSIGNMENT(current_buffer_, &preamble_buffer_);
+ EmitStructType(str);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
- if (c->AllZero()) {
- return;
- }
+ /// Handles generating a pointer declaration
+ /// @param out the output stream
+ /// @param ptr the pointer to emit
+ void EmitPointerType(StringStream& out, const core::type::Pointer* ptr) {
+ if (ptr->Access() == core::Access::kRead) {
+ out << "const ";
+ }
+ EmitAddressSpace(out, ptr->AddressSpace());
+ out << " ";
+ EmitType(out, ptr->StoreType());
+ out << "*";
+ }
- auto count = a->ConstantCount();
+ /// Handles generating an atomic declaration
+ /// @param out the output stream
+ /// @param atomic the atomic to emit
+ void EmitAtomicType(StringStream& out, const core::type::Atomic* atomic) {
+ if (atomic->Type()->Is<core::type::I32>()) {
+ out << "atomic_int";
+ return;
+ }
+ if (TINT_LIKELY(atomic->Type()->Is<core::type::U32>())) {
+ out << "atomic_uint";
+ return;
+ }
+ TINT_ICE() << "unhandled atomic type " << atomic->Type()->FriendlyName();
+ }
+
+ /// Handles generating an array declaration
+ /// @param out the output stream
+ /// @param arr the array to emit
+ void EmitArrayType(StringStream& out, const core::type::Array* arr) {
+ out << ArrayTemplateName() << "<";
+ EmitType(out, arr->ElemType());
+ out << ", ";
+ if (arr->Count()->Is<core::type::RuntimeArrayCount>()) {
+ out << "1";
+ } else {
+ auto count = arr->ConstantCount();
if (!count) {
TINT_ICE() << core::type::Array::kErrExpectedConstantCount;
return;
}
- emit_values(*count);
- },
- [&](const core::type::Struct* s) {
- EmitStructType(s);
- out << StructName(s) << "{";
- TINT_DEFER(out << "}");
+ out << count.value();
+ }
+ out << ">";
+ }
- if (c->AllZero()) {
+ /// Handles generating a vector declaration
+ /// @param out the output stream
+ /// @param vec the vector to emit
+ void EmitVectorType(StringStream& out, const core::type::Vector* vec) {
+ if (vec->Packed()) {
+ out << "packed_";
+ }
+ EmitType(out, vec->type());
+ out << vec->Width();
+ }
+
+ /// Handles generating a matrix declaration
+ /// @param out the output stream
+ /// @param mat the matrix to emit
+ void EmitMatrixType(StringStream& out, const core::type::Matrix* mat) {
+ EmitType(out, mat->type());
+ out << mat->columns() << "x" << mat->rows();
+ }
+
+ /// Handles generating a texture declaration
+ /// @param out the output stream
+ /// @param tex the texture to emit
+ void EmitTextureType(StringStream& out, const core::type::Texture* tex) {
+ if (TINT_UNLIKELY(tex->Is<core::type::ExternalTexture>())) {
+ TINT_ICE() << "Multiplanar external texture transform was not run.";
+ return;
+ }
+
+ if (tex->IsAnyOf<core::type::DepthTexture, core::type::DepthMultisampledTexture>()) {
+ out << "depth";
+ } else {
+ out << "texture";
+ }
+
+ switch (tex->dim()) {
+ case core::type::TextureDimension::k1d:
+ out << "1d";
+ break;
+ case core::type::TextureDimension::k2d:
+ out << "2d";
+ break;
+ case core::type::TextureDimension::k2dArray:
+ out << "2d_array";
+ break;
+ case core::type::TextureDimension::k3d:
+ out << "3d";
+ break;
+ case core::type::TextureDimension::kCube:
+ out << "cube";
+ break;
+ case core::type::TextureDimension::kCubeArray:
+ out << "cube_array";
+ break;
+ default:
+ TINT_ICE() << "invalid texture dimensions";
return;
+ }
+ if (tex->IsAnyOf<core::type::MultisampledTexture, core::type::DepthMultisampledTexture>()) {
+ out << "_ms";
+ }
+ out << "<";
+ TINT_DEFER(out << ">");
+
+ tint::Switch(
+ tex, //
+ [&](const core::type::DepthTexture*) { out << "float, access::sample"; },
+ [&](const core::type::DepthMultisampledTexture*) { out << "float, access::read"; },
+ [&](const core::type::StorageTexture* storage) {
+ EmitType(out, storage->type());
+ out << ", ";
+
+ std::string access_str;
+ if (storage->access() == core::Access::kRead) {
+ out << "access::read";
+ } else if (storage->access() == core::Access::kWrite) {
+ out << "access::write";
+ } else {
+ TINT_ICE() << "invalid access control for storage texture";
+ return;
+ }
+ },
+ [&](const core::type::MultisampledTexture* ms) {
+ EmitType(out, ms->type());
+ out << ", access::read";
+ },
+ [&](const core::type::SampledTexture* sampled) {
+ EmitType(out, sampled->type());
+ out << ", access::sample";
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ /// Handles generating a struct declaration. If the structure has already been emitted, then
+ /// this function will simply return without emitting anything.
+ /// @param str the struct to generate
+ void EmitStructType(const core::type::Struct* str) {
+ auto it = emitted_structs_.emplace(str);
+ if (!it.second) {
+ return;
+ }
+
+ // This does not append directly to the preamble because a struct may require other
+ // structs, or the array template, to get emitted before it. So, the struct emits into a
+ // temporary text buffer, then anything it depends on will emit to the preamble first,
+ // and then it copies the text buffer into the preamble.
+ TextBuffer str_buf;
+ Line(&str_buf) << "struct " << StructName(str) << " {";
+
+ bool is_host_shareable = str->IsHostShareable();
+
+ // Emits a `/* 0xnnnn */` byte offset comment for a struct member.
+ auto add_byte_offset_comment = [&](StringStream& out, uint32_t offset) {
+ std::ios_base::fmtflags saved_flag_state(out.flags());
+ out << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset << " */ ";
+ out.flags(saved_flag_state);
+ };
+
+ auto add_padding = [&](uint32_t size, uint32_t msl_offset) {
+ std::string name;
+ do {
+ name = UniqueIdentifier("tint_pad");
+ } while (str->FindMember(ir_.symbols.Get(name)));
+
+ auto out = Line(&str_buf);
+ add_byte_offset_comment(out, msl_offset);
+ out << ArrayTemplateName() << "<int8_t, " << size << "> " << name << ";";
+ };
+
+ str_buf.IncrementIndent();
+
+ uint32_t msl_offset = 0;
+ for (auto* mem : str->Members()) {
+ auto out = Line(&str_buf);
+ auto mem_name = mem->Name().Name();
+ auto ir_offset = mem->Offset();
+
+ if (is_host_shareable) {
+ if (TINT_UNLIKELY(ir_offset < msl_offset)) {
+ // Unimplementable layout
+ TINT_ICE() << "Structure member offset (" << ir_offset
+ << ") is behind MSL offset (" << msl_offset << ")";
+ return;
+ }
+
+ // Generate padding if required
+ if (auto padding = ir_offset - msl_offset) {
+ add_padding(padding, msl_offset);
+ msl_offset += padding;
+ }
+
+ add_byte_offset_comment(out, msl_offset);
}
- auto members = s->Members();
- for (size_t i = 0; i < members.Length(); i++) {
+ auto* ty = mem->Type();
+
+ EmitType(out, ty);
+ out << " " << mem_name;
+
+ // Emit attributes
+ auto& attributes = mem->Attributes();
+
+ if (auto builtin = attributes.builtin) {
+ auto name = BuiltinToAttribute(builtin.value());
+ if (name.empty()) {
+ TINT_ICE() << "unknown builtin";
+ return;
+ }
+ out << " [[" << name << "]]";
+ }
+
+ if (auto location = attributes.location) {
+ auto& pipeline_stage_uses = str->PipelineStageUses();
+ if (TINT_UNLIKELY(pipeline_stage_uses.size() != 1)) {
+ TINT_ICE() << "invalid entry point IO struct uses";
+ return;
+ }
+
+ if (pipeline_stage_uses.count(core::type::PipelineStageUsage::kVertexInput)) {
+ out << " [[attribute(" + std::to_string(location.value()) + ")]]";
+ } else if (pipeline_stage_uses.count(
+ core::type::PipelineStageUsage::kVertexOutput)) {
+ out << " [[user(locn" + std::to_string(location.value()) + ")]]";
+ } else if (pipeline_stage_uses.count(
+ core::type::PipelineStageUsage::kFragmentInput)) {
+ out << " [[user(locn" + std::to_string(location.value()) + ")]]";
+ } else if (TINT_LIKELY(pipeline_stage_uses.count(
+ core::type::PipelineStageUsage::kFragmentOutput))) {
+ out << " [[color(" + std::to_string(location.value()) + ")]]";
+ } else {
+ TINT_ICE() << "invalid use of location decoration";
+ return;
+ }
+ }
+
+ if (auto interpolation = attributes.interpolation) {
+ auto name = InterpolationToAttribute(interpolation->type, interpolation->sampling);
+ if (name.empty()) {
+ TINT_ICE() << "unknown interpolation attribute";
+ return;
+ }
+ out << " [[" << name << "]]";
+ }
+
+ if (attributes.invariant) {
+ invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
+ out << " " << invariant_define_name_;
+ }
+
+ out << ";";
+
+ if (is_host_shareable) {
+ // Calculate new MSL offset
+ auto size_align = MslPackedTypeSizeAndAlign(ty);
+ if (TINT_UNLIKELY(msl_offset % size_align.align)) {
+ TINT_ICE() << "Misaligned MSL structure member " << mem_name << " : "
+ << ty->FriendlyName() << " offset: " << msl_offset
+ << " align: " << size_align.align;
+ return;
+ }
+ msl_offset += size_align.size;
+ }
+ }
+
+ if (is_host_shareable && str->Size() != msl_offset) {
+ add_padding(str->Size() - msl_offset, msl_offset);
+ }
+
+ str_buf.DecrementIndent();
+ Line(&str_buf) << "};";
+
+ preamble_buffer_.Append(str_buf);
+ }
+
+ /// Handles core::ir::Constant values
+ /// @param out the stream to write the constant too
+ /// @param c the constant to emit
+ void EmitConstant(StringStream& out, core::ir::Constant* c) { EmitConstant(out, c->Value()); }
+
+ /// Handles core::constant::Value values
+ /// @param out the stream to write the constant too
+ /// @param c the constant to emit
+ void EmitConstant(StringStream& out, const core::constant::Value* c) {
+ auto emit_values = [&](uint32_t count) {
+ for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
- out << "." << members[i]->Name().Name() << "=";
EmitConstant(out, c->Index(i));
}
- },
- [&](Default) { UNHANDLED_CASE(c->Type()); });
-}
+ };
-void Printer::EmitZeroValue(StringStream& out, const core::type::Type* ty) {
- Switch(
- ty, [&](const core::type::Bool*) { out << "false"; }, //
- [&](const core::type::F16*) { out << "0.0h"; }, //
- [&](const core::type::F32*) { out << "0.0f"; }, //
- [&](const core::type::I32*) { out << "0"; }, //
- [&](const core::type::U32*) { out << "0u"; }, //
- [&](const core::type::Vector* vec) { EmitZeroValue(out, vec->type()); }, //
- [&](const core::type::Matrix* mat) {
- EmitType(out, mat);
+ tint::Switch(
+ c->Type(), //
+ [&](const core::type::Bool*) { out << (c->ValueAs<bool>() ? "true" : "false"); },
+ [&](const core::type::I32*) { PrintI32(out, c->ValueAs<i32>()); },
+ [&](const core::type::U32*) { out << c->ValueAs<u32>() << "u"; },
+ [&](const core::type::F32*) { PrintF32(out, c->ValueAs<f32>()); },
+ [&](const core::type::F16*) { PrintF16(out, c->ValueAs<f16>()); },
+ [&](const core::type::Vector* v) {
+ EmitType(out, v);
- ScopedParen sp(out);
- EmitZeroValue(out, mat->type());
- },
- [&](const core::type::Array*) { out << "{}"; }, //
- [&](const core::type::Struct*) { out << "{}"; }, //
- [&](Default) { TINT_ICE() << "Invalid type for zero emission: " << ty->FriendlyName(); });
-}
+ ScopedParen sp(out);
+ if (auto* splat = c->As<core::constant::Splat>()) {
+ EmitConstant(out, splat->el);
+ return;
+ }
+ emit_values(v->Width());
+ },
+ [&](const core::type::Matrix* m) {
+ EmitType(out, m);
+ ScopedParen sp(out);
+ emit_values(m->columns());
+ },
+ [&](const core::type::Array* a) {
+ EmitType(out, a);
+ out << "{";
+ TINT_DEFER(out << "}");
-std::string Printer::StructName(const core::type::Struct* s) {
- auto name = s->Name().Name();
- if (HasPrefix(name, "__")) {
- name = tint::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
+ if (c->AllZero()) {
+ return;
+ }
+
+ auto count = a->ConstantCount();
+ if (!count) {
+ TINT_ICE() << core::type::Array::kErrExpectedConstantCount;
+ return;
+ }
+ emit_values(*count);
+ },
+ [&](const core::type::Struct* s) {
+ EmitStructType(s);
+ out << StructName(s) << "{";
+ TINT_DEFER(out << "}");
+
+ if (c->AllZero()) {
+ return;
+ }
+
+ auto members = s->Members();
+ for (size_t i = 0; i < members.Length(); i++) {
+ if (i > 0) {
+ out << ", ";
+ }
+ out << "." << members[i]->Name().Name() << "=";
+ EmitConstant(out, c->Index(i));
+ }
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
- return name;
-}
-std::string Printer::UniqueIdentifier(const std::string& prefix /* = "" */) {
- return ir_.symbols.New(prefix).Name();
-}
+ /// Emits the zero value for the given type
+ /// @param out the stream to emit too
+ /// @param ty the type
+ void EmitZeroValue(StringStream& out, const core::type::Type* ty) {
+ Switch(
+ ty, [&](const core::type::Bool*) { out << "false"; }, //
+ [&](const core::type::F16*) { out << "0.0h"; }, //
+ [&](const core::type::F32*) { out << "0.0f"; }, //
+ [&](const core::type::I32*) { out << "0"; }, //
+ [&](const core::type::U32*) { out << "0u"; }, //
+ [&](const core::type::Vector* vec) { EmitZeroValue(out, vec->type()); }, //
+ [&](const core::type::Matrix* mat) {
+ EmitType(out, mat);
-TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+ ScopedParen sp(out);
+ EmitZeroValue(out, mat->type());
+ },
+ [&](const core::type::Array*) { out << "{}"; }, //
+ [&](const core::type::Struct*) { out << "{}"; }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
-std::string Printer::Expr(core::ir::Value* value, PtrKind want_ptr_kind) {
- using ExprAndPtrKind = std::pair<std::string, PtrKind>;
+ /// @param s the structure
+ /// @returns the name of the structure, taking special care of builtin structures that start
+ /// with double underscores. If the structure is a builtin, then the returned name will be a
+ /// unique name without the leading underscores.
+ std::string StructName(const core::type::Struct* s) {
+ auto name = s->Name().Name();
+ if (HasPrefix(name, "__")) {
+ name = tint::GetOrCreate(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
+ }
+ return name;
+ }
- auto [expr, got_ptr_kind] = tint::Switch(
- value,
- [&](core::ir::Constant* c) -> ExprAndPtrKind {
- StringStream str;
- EmitConstant(str, c);
- return {str.str(), PtrKind::kRef};
- },
- [&](Default) -> ExprAndPtrKind {
- auto lookup = bindings_.Find(value);
- if (TINT_UNLIKELY(!lookup)) {
- TINT_ICE() << "Expr(" << (value ? value->TypeInfo().name : "null")
- << ") value has no expression";
- return {};
- }
+ /// @return a new, unique identifier with the given prefix.
+ /// @param prefix optional prefix to apply to the generated identifier. If empty "tint_symbol"
+ /// will be used.
+ std::string UniqueIdentifier(const std::string& prefix /* = "" */) {
+ return ir_.symbols.New(prefix).Name();
+ }
- return std::visit(
- [&](auto&& got) -> ExprAndPtrKind {
- using T = std::decay_t<decltype(got)>;
+ TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
- if constexpr (std::is_same_v<T, VariableValue>) {
- return {got.name.Name(), got.ptr_kind};
- }
+ /// Returns the expression for the given value
+ /// @param value the value to lookup
+ /// @param want_ptr_kind the pointer information for the return
+ /// @returns the string expression
+ std::string Expr(core::ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef) {
+ using ExprAndPtrKind = std::pair<std::string, PtrKind>;
- if constexpr (std::is_same_v<T, InlinedValue>) {
- auto result = ExprAndPtrKind{got.expr, got.ptr_kind};
-
- // Single use (inlined) expression.
- // Mark the bindings_ map entry as consumed.
- *lookup = ConsumedValue{};
- return result;
- }
-
- if constexpr (std::is_same_v<T, ConsumedValue>) {
- TINT_ICE() << "Expr(" << value->TypeInfo().name
- << ") called twice on the same value";
- } else {
- TINT_ICE() << "Expr(" << value->TypeInfo().name << ") has unhandled value";
- }
+ auto [expr, got_ptr_kind] = tint::Switch(
+ value,
+ [&](core::ir::Constant* c) -> ExprAndPtrKind {
+ StringStream str;
+ EmitConstant(str, c);
+ return {str.str(), PtrKind::kRef};
+ },
+ [&](Default) -> ExprAndPtrKind {
+ auto lookup = bindings_.Find(value);
+ if (TINT_UNLIKELY(!lookup)) {
+ TINT_ICE() << "Expr(" << (value ? value->TypeInfo().name : "null")
+ << ") value has no expression";
return {};
- },
- *lookup);
- });
- if (expr.empty()) {
- return "<error>";
+ }
+
+ return std::visit(
+ [&](auto&& got) -> ExprAndPtrKind {
+ using T = std::decay_t<decltype(got)>;
+
+ if constexpr (std::is_same_v<T, VariableValue>) {
+ return {got.name.Name(), got.ptr_kind};
+ }
+
+ if constexpr (std::is_same_v<T, InlinedValue>) {
+ auto result = ExprAndPtrKind{got.expr, got.ptr_kind};
+
+ // Single use (inlined) expression.
+ // Mark the bindings_ map entry as consumed.
+ *lookup = ConsumedValue{};
+ return result;
+ }
+
+ if constexpr (std::is_same_v<T, ConsumedValue>) {
+ TINT_ICE() << "Expr(" << value->TypeInfo().name
+ << ") called twice on the same value";
+ } else {
+ TINT_ICE()
+ << "Expr(" << value->TypeInfo().name << ") has unhandled value";
+ }
+ return {};
+ },
+ *lookup);
+ });
+ if (expr.empty()) {
+ return "<error>";
+ }
+
+ if (value->Type()->Is<core::type::Pointer>()) {
+ return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
+ }
+
+ return expr;
}
- if (value->Type()->Is<core::type::Pointer>()) {
- return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
+ TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+
+ /// Returns the given expression converted to the given pointer kind
+ /// @param in the input expression
+ /// @param got the pointer kind we have
+ /// @param want the pointer kind we want
+ std::string ToPtrKind(const std::string& in, PtrKind got, PtrKind want) {
+ if (want == PtrKind::kRef && got == PtrKind::kPtr) {
+ return "*(" + in + ")";
+ }
+ if (want == PtrKind::kPtr && got == PtrKind::kRef) {
+ return "&(" + in + ")";
+ }
+ return in;
}
- return expr;
-}
+ /// Associates an IR value with a result expression
+ /// @param value the IR value
+ /// @param expr the result expression
+ /// @param ptr_kind defines how pointer values are represented by the expression
+ void Bind(core::ir::Value* value, const std::string& expr, PtrKind ptr_kind = PtrKind::kRef) {
+ TINT_ASSERT(value);
-TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+ if (can_inline_.Remove(value)) {
+ // Value will be inlined at its place of usage.
+ if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
+ return;
+ }
+ } else {
+ auto mod_name = ir_.NameOf(value);
+ if (value->Usages().IsEmpty() && !mod_name.IsValid()) {
+ // Drop phonies.
+ } else {
+ if (mod_name.Name().empty()) {
+ mod_name = ir_.symbols.New("v");
+ }
-std::string Printer::ToPtrKind(const std::string& in, PtrKind got, PtrKind want) {
- if (want == PtrKind::kRef && got == PtrKind::kPtr) {
- return "*(" + in + ")";
- }
- if (want == PtrKind::kPtr && got == PtrKind::kRef) {
- return "&(" + in + ")";
- }
- return in;
-}
+ auto out = Line();
+ EmitType(out, value->Type());
+ out << " const " << mod_name.Name() << " = ";
+ if (value->Type()->Is<core::type::Pointer>()) {
+ out << ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
+ } else {
+ out << expr;
+ }
+ out << ";";
-void Printer::Bind(core::ir::Value* value, const std::string& expr, PtrKind ptr_kind) {
- TINT_ASSERT(value);
-
- if (can_inline_.Remove(value)) {
- // Value will be inlined at its place of usage.
- if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
+ Bind(value, mod_name, PtrKind::kPtr);
+ }
return;
}
- } else {
- auto mod_name = ir_.NameOf(value);
- if (value->Usages().IsEmpty() && !mod_name.IsValid()) {
- // Drop phonies.
- } else {
- if (mod_name.Name().empty()) {
- mod_name = ir_.symbols.New("v");
- }
- auto out = Line();
- EmitType(out, value->Type());
- out << " const " << mod_name.Name() << " = ";
- if (value->Type()->Is<core::type::Pointer>()) {
- out << ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
- } else {
- out << expr;
- }
- out << ";";
-
- Bind(value, mod_name, PtrKind::kPtr);
- }
- return;
- }
-
- TINT_ICE() << "Bind(" << value->TypeInfo().name << ") called twice for same value";
-}
-
-void Printer::Bind(core::ir::Value* value, Symbol name, PtrKind ptr_kind) {
- TINT_ASSERT(value);
-
- bool added = bindings_.Add(value, VariableValue{name, ptr_kind});
- if (TINT_UNLIKELY(!added)) {
TINT_ICE() << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
-}
-void Printer::MarkInlinable(core::ir::Block* block) {
- // An ordered list of possibly-inlinable values returned by sequenced instructions that have
- // not yet been marked-for or ruled-out-for inlining.
- UniqueVector<core::ir::Value*, 32> pending_resolution;
+ /// Associates an IR value the 'var', 'let' or parameter of the given name
+ /// @param value the IR value
+ /// @param name the name for the value
+ /// @param ptr_kind defines how pointer values are represented by @p expr.
+ void Bind(core::ir::Value* value, Symbol name, PtrKind ptr_kind = PtrKind::kRef) {
+ TINT_ASSERT(value);
- // Walk the instructions of the block starting with the first.
- for (auto* inst : *block) {
- // Is the instruction sequenced?
- bool sequenced = inst->Sequenced();
-
- if (inst->Results().Length() != 1) {
- continue;
- }
-
- // Instruction has a single result value.
- // Check to see if the result of this instruction is a candidate for inlining.
- auto* result = inst->Result();
- // Only values with a single usage can be inlined.
- // Named values are not inlined, as we want to emit the name for a let.
- if (result->Usages().Count() == 1 && !ir_.NameOf(result).IsValid()) {
- if (sequenced) {
- // The value comes from a sequenced instruction. Don't inline.
- } else {
- // The value comes from an unsequenced instruction. Just inline.
- can_inline_.Add(result);
- }
- continue;
+ bool added = bindings_.Add(value, VariableValue{name, ptr_kind});
+ if (TINT_UNLIKELY(!added)) {
+ TINT_ICE() << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
}
+
+ /// Marks instructions in a block for inlineability
+ /// @param block the block
+ void MarkInlinable(core::ir::Block* block) {
+ // An ordered list of possibly-inlinable values returned by sequenced instructions that have
+ // not yet been marked-for or ruled-out-for inlining.
+ UniqueVector<core::ir::Value*, 32> pending_resolution;
+
+ // Walk the instructions of the block starting with the first.
+ for (auto* inst : *block) {
+ // Is the instruction sequenced?
+ bool sequenced = inst->Sequenced();
+
+ if (inst->Results().Length() != 1) {
+ continue;
+ }
+
+ // Instruction has a single result value.
+ // Check to see if the result of this instruction is a candidate for inlining.
+ auto* result = inst->Result();
+ // Only values with a single usage can be inlined.
+ // Named values are not inlined, as we want to emit the name for a let.
+ if (result->Usages().Count() == 1 && !ir_.NameOf(result).IsValid()) {
+ if (sequenced) {
+ // The value comes from a sequenced instruction. Don't inline.
+ } else {
+ // The value comes from an unsequenced instruction. Just inline.
+ can_inline_.Add(result);
+ }
+ continue;
+ }
+ }
+ }
+};
+} // namespace
+
+Result<std::string> Print(core::ir::Module& module) {
+ return Printer{module}.Generate();
}
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.h b/src/tint/lang/msl/writer/printer/printer.h
index 09f111c..a5218c3 100644
--- a/src/tint/lang/msl/writer/printer/printer.h
+++ b/src/tint/lang/msl/writer/printer/printer.h
@@ -29,231 +29,19 @@
#define SRC_TINT_LANG_MSL_WRITER_PRINTER_PRINTER_H_
#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include "src/tint/lang/core/ir/module.h"
-#include "src/tint/lang/core/type/texture.h"
-#include "src/tint/utils/diagnostic/diagnostic.h"
-#include "src/tint/utils/generator/text_generator.h"
-#include "src/tint/utils/text/string_stream.h"
+#include "src/tint/utils/result/result.h"
// Forward declarations
namespace tint::core::ir {
-class Binary;
-class ExitIf;
-class If;
-class Let;
-class Load;
-class Return;
-class Unreachable;
-class Var;
+class Module;
} // namespace tint::core::ir
namespace tint::msl::writer {
-/// Implementation class for the MSL generator
-class Printer : public tint::TextGenerator {
- public:
- /// Constructor
- /// @param module the Tint IR module to generate
- explicit Printer(core::ir::Module& module);
- ~Printer() override;
-
- /// @returns success or failure
- tint::Result<SuccessType> Generate();
-
- /// @copydoc tint::TextGenerator::Result
- std::string Result() const override;
-
- private:
- /// Emit the function
- /// @param func the function to emit
- void EmitFunction(core::ir::Function* func);
-
- /// Emit a block
- /// @param block the block to emit
- void EmitBlock(core::ir::Block* block);
- /// Emit the instructions in a block
- /// @param block the block with the instructions to emit
- void EmitBlockInstructions(core::ir::Block* block);
-
- /// Emit an if instruction
- /// @param if_ the if instruction
- void EmitIf(core::ir::If* if_);
- /// Emit an exit-if instruction
- /// @param e the exit-if instruction
- void EmitExitIf(core::ir::ExitIf* e);
-
- /// Emit a let instruction
- /// @param l the let instruction
- void EmitLet(core::ir::Let* l);
- /// Emit a var instruction
- /// @param v the var instruction
- void EmitVar(core::ir::Var* v);
- /// Emit a load instruction
- /// @param l the load instruction
- void EmitLoad(core::ir::Load* l);
-
- /// Emit a return instruction
- /// @param r the return instruction
- void EmitReturn(core::ir::Return* r);
- /// Emit an unreachable instruction
- void EmitUnreachable();
-
- /// Emit a binary instruction
- /// @param b the binary instruction
- void EmitBinary(core::ir::Binary* b);
-
- /// Emit a type
- /// @param out the stream to emit too
- /// @param ty the type to emit
- void EmitType(StringStream& out, const core::type::Type* ty);
-
- /// Handles generating an array declaration
- /// @param out the output stream
- /// @param arr the array to emit
- void EmitArrayType(StringStream& out, const core::type::Array* arr);
- /// Handles generating an atomic declaration
- /// @param out the output stream
- /// @param atomic the atomic to emit
- void EmitAtomicType(StringStream& out, const core::type::Atomic* atomic);
- /// Handles generating a pointer declaration
- /// @param out the output stream
- /// @param ptr the pointer to emit
- void EmitPointerType(StringStream& out, const core::type::Pointer* ptr);
- /// Handles generating a vector declaration
- /// @param out the output stream
- /// @param vec the vector to emit
- void EmitVectorType(StringStream& out, const core::type::Vector* vec);
- /// Handles generating a matrix declaration
- /// @param out the output stream
- /// @param mat the matrix to emit
- void EmitMatrixType(StringStream& out, const core::type::Matrix* mat);
- /// Handles generating a texture declaration
- /// @param out the output stream
- /// @param tex the texture to emit
- void EmitTextureType(StringStream& out, const core::type::Texture* tex);
- /// Handles generating a struct declaration. If the structure has already been emitted, then
- /// this function will simply return without emitting anything.
- /// @param str the struct to generate
- void EmitStructType(const core::type::Struct* str);
-
- /// Handles generating a address space
- /// @param out the output of the type stream
- /// @param sc the address space to generate
- void EmitAddressSpace(StringStream& out, core::AddressSpace sc);
-
- /// Handles core::ir::Constant values
- /// @param out the stream to write the constant too
- /// @param c the constant to emit
- void EmitConstant(StringStream& out, core::ir::Constant* c);
- /// Handles core::constant::Value values
- /// @param out the stream to write the constant too
- /// @param c the constant to emit
- void EmitConstant(StringStream& out, const core::constant::Value* c);
-
- /// Emits the zero value for the given type
- /// @param out the stream to emit too
- /// @param ty the type
- void EmitZeroValue(StringStream& out, const core::type::Type* ty);
-
- /// @returns the name of the templated `tint_array` helper type, generating it if needed
- const std::string& ArrayTemplateName();
-
- /// @param s the structure
- /// @returns the name of the structure, taking special care of builtin structures that start
- /// with double underscores. If the structure is a builtin, then the returned name will be a
- /// unique name without the leading underscores.
- std::string StructName(const core::type::Struct* s);
-
- /// @return a new, unique identifier with the given prefix.
- /// @param prefix optional prefix to apply to the generated identifier. If empty "tint_symbol"
- /// will be used.
- std::string UniqueIdentifier(const std::string& prefix = "");
-
- /// Map of builtin structure to unique generated name
- std::unordered_map<const core::type::Struct*, std::string> builtin_struct_names_;
-
- core::ir::Module& ir_;
-
- /// The buffer holding preamble text
- TextBuffer preamble_buffer_;
-
- /// Unique name of the 'TINT_INVARIANT' preprocessor define.
- /// Non-empty only if an invariant attribute has been generated.
- std::string invariant_define_name_;
-
- std::unordered_set<const core::type::Struct*> emitted_structs_;
-
- /// The current function being emitted
- core::ir::Function* current_function_ = nullptr;
- /// The current block being emitted
- core::ir::Block* current_block_ = nullptr;
-
- /// Unique name of the tint_array<T, N> template.
- /// Non-empty only if the template has been generated.
- std::string array_template_name_;
-
- /// The representation for an IR pointer type
- enum class PtrKind {
- kPtr, // IR pointer is represented in a pointer
- kRef, // IR pointer is represented in a reference
- };
-
- /// The structure for a value held by a 'let', 'var' or parameter.
- struct VariableValue {
- Symbol name; // Name of the variable
- PtrKind ptr_kind = PtrKind::kRef;
- };
-
- /// The structure for an inlined value
- struct InlinedValue {
- std::string expr;
- PtrKind ptr_kind = PtrKind::kRef;
- };
-
- /// Empty struct used as a sentinel value to indicate that an string expression has been
- /// consumed by its single place of usage. Attempting to use this value a second time should
- /// result in an ICE.
- struct ConsumedValue {};
-
- using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
-
- /// IR values to their representation
- Hashmap<core::ir::Value*, ValueBinding, 32> bindings_;
-
- /// Values that can be inlined.
- Hashset<core::ir::Value*, 64> can_inline_;
-
- /// Returns the expression for the given value
- /// @param value the value to lookup
- /// @param want_ptr_kind the pointer information for the return
- /// @returns the string expression
- std::string Expr(core::ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef);
-
- /// Returns the given expression converted to the given pointer kind
- /// @param in the input expression
- /// @param got the pointer kind we have
- /// @param want the pointer kind we want
- std::string ToPtrKind(const std::string& in, PtrKind got, PtrKind want);
-
- /// Associates an IR value with a result expression
- /// @param value the IR value
- /// @param expr the result expression
- /// @param ptr_kind defines how pointer values are represented by the expression
- void Bind(core::ir::Value* value, const std::string& expr, PtrKind ptr_kind = PtrKind::kRef);
-
- /// Associates an IR value the 'var', 'let' or parameter of the given name
- /// @param value the IR value
- /// @param name the name for the value
- /// @param ptr_kind defines how pointer values are represented by @p expr.
- void Bind(core::ir::Value* value, Symbol name, PtrKind ptr_kind = PtrKind::kRef);
-
- /// Marks instructions in a block for inlineability
- /// @param block the block
- void MarkInlinable(core::ir::Block* block);
-};
+/// @returns the generated MSL shader on success, or failure
+/// @param module the Tint IR module to generate
+Result<std::string> Print(core::ir::Module& module);
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/writer.cc b/src/tint/lang/msl/writer/writer.cc
index 8f854df..78f3040 100644
--- a/src/tint/lang/msl/writer/writer.cc
+++ b/src/tint/lang/msl/writer/writer.cc
@@ -69,12 +69,11 @@
}
// Generate the MSL code.
- auto impl = std::make_unique<Printer>(ir);
- auto result = impl->Generate();
+ auto result = Print(ir);
if (!result) {
return result.Failure();
}
- output.msl = impl->Result();
+ output.msl = result.Get();
#else
return Failure{"use_tint_ir requires building with TINT_BUILD_WGSL_READER"};
#endif
diff --git a/src/tint/lang/spirv/reader/ast_lower/atomics.cc b/src/tint/lang/spirv/reader/ast_lower/atomics.cc
index 09b5fe6..fa1f371 100644
--- a/src/tint/lang/spirv/reader/ast_lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/atomics.cc
@@ -236,11 +236,8 @@
return b.ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()),
ptr->Access());
},
- [&](const core::type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
- [&](Default) {
- TINT_ICE() << "unhandled type: " << ty->FriendlyName();
- return ast::Type{};
- });
+ [&](const core::type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ReplaceLoadsAndStores() {
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
index 4875d2c..2c7d268 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
@@ -2054,11 +2054,8 @@
const bool value =
spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value();
return TypedExpression{ty_.Bool(), create<ast::BoolLiteralExpression>(source, value)};
- },
- [&](Default) {
- Fail() << "expected scalar constant";
- return TypedExpression{};
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
const ast::Expression* ASTParser::MakeNullValue(const Type* type) {
@@ -2100,11 +2097,8 @@
}
return builder_.Call(Source{}, original_type->Build(builder_),
std::move(ast_components));
- },
- [&](Default) {
- Fail() << "can't make null value for type: " << type->TypeInfo().name;
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
TypedExpression ASTParser::MakeNullExpression(const Type* type) {
diff --git a/src/tint/lang/spirv/reader/ast_parser/helper_test.cc b/src/tint/lang/spirv/reader/ast_parser/helper_test.cc
index e0a5cd2..a417c51 100644
--- a/src/tint/lang/spirv/reader/ast_parser/helper_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/helper_test.cc
@@ -88,10 +88,8 @@
}
return writer.Result();
},
- [&](const ast::Identifier* ident) { return ident->symbol.Name(); },
- [&](Default) {
- return "<unhandled AST node type " + std::string(node->TypeInfo().name) + ">";
- });
+ [&](const ast::Identifier* ident) { return ident->symbol.Name(); }, //
+ TINT_ICE_ON_NO_MATCH);
}
} // namespace tint::spirv::reader::ast_parser::test
diff --git a/src/tint/lang/spirv/writer/BUILD.bazel b/src/tint/lang/spirv/writer/BUILD.bazel
index a03bea0..1525211 100644
--- a/src/tint/lang/spirv/writer/BUILD.bazel
+++ b/src/tint/lang/spirv/writer/BUILD.bazel
@@ -50,12 +50,8 @@
"//src/tint/api/options",
"//src/tint/lang/core",
"//src/tint/lang/core/constant",
- "//src/tint/lang/core/intrinsic",
"//src/tint/lang/core/ir",
"//src/tint/lang/core/type",
- "//src/tint/lang/spirv",
- "//src/tint/lang/spirv/intrinsic",
- "//src/tint/lang/spirv/ir",
"//src/tint/lang/wgsl",
"//src/tint/lang/wgsl/ast",
"//src/tint/lang/wgsl/program",
@@ -128,9 +124,6 @@
"//src/tint/lang/core/intrinsic",
"//src/tint/lang/core/ir",
"//src/tint/lang/core/type",
- "//src/tint/lang/spirv",
- "//src/tint/lang/spirv/intrinsic",
- "//src/tint/lang/spirv/ir",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
"//src/tint/utils/ice",
diff --git a/src/tint/lang/spirv/writer/BUILD.cmake b/src/tint/lang/spirv/writer/BUILD.cmake
index 7f67a5d..767a497 100644
--- a/src/tint/lang/spirv/writer/BUILD.cmake
+++ b/src/tint/lang/spirv/writer/BUILD.cmake
@@ -58,12 +58,8 @@
tint_api_options
tint_lang_core
tint_lang_core_constant
- tint_lang_core_intrinsic
tint_lang_core_ir
tint_lang_core_type
- tint_lang_spirv
- tint_lang_spirv_intrinsic
- tint_lang_spirv_ir
tint_lang_wgsl
tint_lang_wgsl_ast
tint_lang_wgsl_program
@@ -142,9 +138,6 @@
tint_lang_core_intrinsic
tint_lang_core_ir
tint_lang_core_type
- tint_lang_spirv
- tint_lang_spirv_intrinsic
- tint_lang_spirv_ir
tint_utils_containers
tint_utils_diagnostic
tint_utils_ice
diff --git a/src/tint/lang/spirv/writer/BUILD.gn b/src/tint/lang/spirv/writer/BUILD.gn
index 552d063..d9cefee 100644
--- a/src/tint/lang/spirv/writer/BUILD.gn
+++ b/src/tint/lang/spirv/writer/BUILD.gn
@@ -53,12 +53,8 @@
"${tint_src_dir}/api/options",
"${tint_src_dir}/lang/core",
"${tint_src_dir}/lang/core/constant",
- "${tint_src_dir}/lang/core/intrinsic",
"${tint_src_dir}/lang/core/ir",
"${tint_src_dir}/lang/core/type",
- "${tint_src_dir}/lang/spirv",
- "${tint_src_dir}/lang/spirv/intrinsic",
- "${tint_src_dir}/lang/spirv/ir",
"${tint_src_dir}/lang/wgsl",
"${tint_src_dir}/lang/wgsl/ast",
"${tint_src_dir}/lang/wgsl/program",
@@ -130,9 +126,6 @@
"${tint_src_dir}/lang/core/intrinsic",
"${tint_src_dir}/lang/core/ir",
"${tint_src_dir}/lang/core/type",
- "${tint_src_dir}/lang/spirv",
- "${tint_src_dir}/lang/spirv/intrinsic",
- "${tint_src_dir}/lang/spirv/ir",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
"${tint_src_dir}/utils/ice",
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.cc b/src/tint/lang/spirv/writer/ast_printer/builder.cc
index ee22481..734c380 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.cc
@@ -541,11 +541,8 @@
[&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
[&](const ast::IdentifierExpression* i) { return GenerateIdentifierExpression(i); },
[&](const ast::LiteralExpression* l) { return GenerateLiteralIfNeeded(l); },
- [&](const ast::UnaryOpExpression* u) { return GenerateUnaryOpExpression(u); },
- [&](Default) {
- TINT_ICE() << "unknown expression type: " + std::string(expr->TypeInfo().name);
- return 0;
- });
+ [&](const ast::UnaryOpExpression* u) { return GenerateUnaryOpExpression(u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
@@ -858,11 +855,8 @@
},
[&](const ast::InternalAttribute*) {
return true; // ignored
- },
- [&](Default) {
- TINT_ICE() << "unknown attribute";
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
}
@@ -1049,11 +1043,8 @@
info->source_id = result_id;
info->source_type = expr_type;
return true;
- },
- [&](Default) {
- TINT_ICE() << "unhandled member index type: " << expr_sem->TypeInfo().name;
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
uint32_t Builder::GenerateAccessorExpression(const ast::AccessorExpression* expr) {
@@ -1097,11 +1088,8 @@
},
[&](const ast::MemberAccessorExpression* member) {
return GenerateMemberAccessor(member, &info);
- },
- [&](Default) {
- TINT_ICE() << "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
}
@@ -1620,8 +1608,8 @@
constant.value.f16 = {f16(static_cast<float>(f->value)).BitsRepresentation()};
return;
}
- },
- [&](Default) { TINT_ICE() << "unknown literal type"; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (has_error()) {
return false;
@@ -1699,11 +1687,8 @@
}
return composite(count.value());
},
- [&](const core::type::Struct* s) { return composite(s->Members().Length()); },
- [&](Default) {
- TINT_ICE() << "unhandled constant type: " + ty->FriendlyName();
- return 0;
- });
+ [&](const core::type::Struct* s) { return composite(s->Members().Length()); }, //
+ TINT_ICE_ON_NO_MATCH);
}
uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
@@ -2232,11 +2217,8 @@
},
[&](const sem::ValueConstructor*) {
return GenerateValueConstructorOrConversion(call, nullptr);
- },
- [&](Default) {
- TINT_ICE() << "unhandled call target: " << target->TypeInfo().name;
- return 0;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
uint32_t Builder::GenerateFunctionCall(const sem::Call* call, const sem::Function* fn) {
@@ -3639,11 +3621,8 @@
[&](const ast::VariableDeclStatement* v) { return GenerateVariableDeclStatement(v); },
[&](const ast::ConstAssert*) {
return true; // Not emitted
- },
- [&](Default) {
- TINT_ICE() << "unknown statement type: " + std::string(stmt->TypeInfo().name);
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
bool Builder::GenerateVariableDeclStatement(const ast::VariableDeclStatement* stmt) {
@@ -3767,11 +3746,8 @@
core::type::SamplerKind::kSampler)] = id;
}
return true;
- },
- [&](Default) {
- TINT_ICE() << "unable to convert type: " + type->FriendlyName();
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return 0;
@@ -3843,8 +3819,8 @@
},
[&](const core::type::SampledTexture* t) { return GenerateTypeIfNeeded(t->type()); },
[&](const core::type::MultisampledTexture* t) { return GenerateTypeIfNeeded(t->type()); },
- [&](const core::type::StorageTexture* t) { return GenerateTypeIfNeeded(t->type()); },
- [&](Default) { return 0u; });
+ [&](const core::type::StorageTexture* t) { return GenerateTypeIfNeeded(t->type()); }, //
+ TINT_ICE_ON_NO_MATCH);
if (type_id == 0u) {
return false;
}
diff --git a/src/tint/lang/spirv/writer/ast_raise/merge_return.cc b/src/tint/lang/spirv/writer/ast_raise/merge_return.cc
index b36faea..07216c4 100644
--- a/src/tint/lang/spirv/writer/ast_raise/merge_return.cc
+++ b/src/tint/lang/spirv/writer/ast_raise/merge_return.cc
@@ -154,8 +154,8 @@
[&](const ast::WhileStatement* w) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(w->body);
- },
- [&](Default) { TINT_ICE() << "unhandled statement type"; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ProcessBlock(const ast::BlockStatement* block) {
diff --git a/src/tint/lang/spirv/writer/common/BUILD.bazel b/src/tint/lang/spirv/writer/common/BUILD.bazel
index 3ad9cfa..7a3f041 100644
--- a/src/tint/lang/spirv/writer/common/BUILD.bazel
+++ b/src/tint/lang/spirv/writer/common/BUILD.bazel
@@ -96,9 +96,6 @@
"//src/tint/lang/core/intrinsic",
"//src/tint/lang/core/ir",
"//src/tint/lang/core/type",
- "//src/tint/lang/spirv",
- "//src/tint/lang/spirv/intrinsic",
- "//src/tint/lang/spirv/ir",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
"//src/tint/utils/ice",
diff --git a/src/tint/lang/spirv/writer/common/BUILD.cmake b/src/tint/lang/spirv/writer/common/BUILD.cmake
index ce95129..94f5a1b 100644
--- a/src/tint/lang/spirv/writer/common/BUILD.cmake
+++ b/src/tint/lang/spirv/writer/common/BUILD.cmake
@@ -101,9 +101,6 @@
tint_lang_core_intrinsic
tint_lang_core_ir
tint_lang_core_type
- tint_lang_spirv
- tint_lang_spirv_intrinsic
- tint_lang_spirv_ir
tint_utils_containers
tint_utils_diagnostic
tint_utils_ice
diff --git a/src/tint/lang/spirv/writer/common/BUILD.gn b/src/tint/lang/spirv/writer/common/BUILD.gn
index 75d6cdc..53c6f01 100644
--- a/src/tint/lang/spirv/writer/common/BUILD.gn
+++ b/src/tint/lang/spirv/writer/common/BUILD.gn
@@ -98,9 +98,6 @@
"${tint_src_dir}/lang/core/intrinsic",
"${tint_src_dir}/lang/core/ir",
"${tint_src_dir}/lang/core/type",
- "${tint_src_dir}/lang/spirv",
- "${tint_src_dir}/lang/spirv/intrinsic",
- "${tint_src_dir}/lang/spirv/ir",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
"${tint_src_dir}/utils/ice",
diff --git a/src/tint/lang/spirv/writer/common/helper_test.h b/src/tint/lang/spirv/writer/common/helper_test.h
index fc616ce..f04c95c 100644
--- a/src/tint/lang/spirv/writer/common/helper_test.h
+++ b/src/tint/lang/spirv/writer/common/helper_test.h
@@ -89,8 +89,6 @@
template <typename BASE>
class SpirvWriterTestHelperBase : public BASE {
public:
- SpirvWriterTestHelperBase() : writer_(mod, false) {}
-
/// The test module.
core::ir::Module mod;
/// The test builder.
@@ -99,51 +97,48 @@
core::type::Manager& ty{mod.Types()};
protected:
- /// The SPIR-V writer.
- Printer writer_;
-
/// Errors produced during codegen or SPIR-V validation.
std::string err_;
/// SPIR-V output.
std::string output_;
+ /// The generated SPIR-V
+ writer::Module spirv_;
+
/// @returns the error string from the validation
std::string Error() const { return err_; }
- /// Run the specified writer on the IR module and validate the result.
- /// @param writer the writer to use for SPIR-V generation
+ /// Run the printer on the IR module and validate the result.
/// @param options the optional writer options to use when raising the IR
+ /// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
+ /// storage class with OpConstantNull
/// @returns true if generation and validation succeeded
- bool Generate(Printer& writer, Options options = {}) {
+ bool Generate(Options options = {}, bool zero_init_workgroup_memory = false) {
auto raised = raise::Raise(mod, options);
if (!raised) {
err_ = raised.Failure().reason.str();
return false;
}
- auto spirv = writer.Generate();
+ auto spirv = PrintModule(mod, zero_init_workgroup_memory);
if (!spirv) {
err_ = spirv.Failure().reason.str();
return false;
}
- output_ = Disassemble(spirv.Get(), SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
- SPV_BINARY_TO_TEXT_OPTION_INDENT |
- SPV_BINARY_TO_TEXT_OPTION_COMMENT);
+ output_ = Disassemble(spirv->Code(), SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
+ SPV_BINARY_TO_TEXT_OPTION_INDENT |
+ SPV_BINARY_TO_TEXT_OPTION_COMMENT);
- if (!Validate(spirv.Get())) {
+ if (!Validate(spirv->Code())) {
return false;
}
+ spirv_ = std::move(spirv.Get());
return true;
}
- /// Run the writer on the IR module and validate the result.
- /// @param options the optional writer options to use when raising the IR
- /// @returns true if generation and validation succeeded
- bool Generate(Options options = {}) { return Generate(writer_, options); }
-
/// Validate the generated SPIR-V using the SPIR-V Tools Validator.
/// @param binary the SPIR-V binary module to validate
/// @returns true if validation succeeded, false otherwise
@@ -180,7 +175,7 @@
}
/// @returns the disassembled types from the generated module.
- std::string DumpTypes() { return DumpInstructions(writer_.Module().Types()); }
+ std::string DumpTypes() { return DumpInstructions(spirv_.Types()); }
/// Helper to make a scalar type corresponding to the element type `type`.
/// @param type the element type
diff --git a/src/tint/lang/spirv/writer/common/module.cc b/src/tint/lang/spirv/writer/common/module.cc
index 622ce65..b93ccd7 100644
--- a/src/tint/lang/spirv/writer/common/module.cc
+++ b/src/tint/lang/spirv/writer/common/module.cc
@@ -43,10 +43,18 @@
} // namespace
-Module::Module() {}
+Module::Module() = default;
+
+Module::Module(const Module&) = default;
+
+Module::Module(Module&&) = default;
Module::~Module() = default;
+Module& Module::operator=(const Module& other) = default;
+
+Module& Module::operator=(Module&& other) = default;
+
uint32_t Module::TotalSize() const {
// The 5 covers the magic, version, generator, id bound and reserved.
uint32_t size = 5;
diff --git a/src/tint/lang/spirv/writer/common/module.h b/src/tint/lang/spirv/writer/common/module.h
index bf2ffb1..88f0b87 100644
--- a/src/tint/lang/spirv/writer/common/module.h
+++ b/src/tint/lang/spirv/writer/common/module.h
@@ -45,9 +45,27 @@
/// Constructor
Module();
+ /// Copy constructor
+ /// @param other the other Module to copy
+ Module(const Module& other);
+
+ /// Move constructor
+ /// @param other the other Module to move
+ Module(Module&& other);
+
/// Destructor
~Module();
+ /// Copy-assignment operator
+ /// @param other the other Module to copy
+ /// @returns this Module
+ Module& operator=(const Module& other);
+
+ /// Move-assignment operator
+ /// @param other the other Module to move
+ /// @returns this Module
+ Module& operator=(Module&& other);
+
/// @returns the number of uint32_t's needed to make up the results
uint32_t TotalSize() const;
@@ -155,6 +173,9 @@
/// @returns the functions
const std::vector<Function>& Functions() const { return functions_; }
+ /// @returns the SPIR-V code as a vector of uint32_t
+ std::vector<uint32_t>& Code() { return code_; }
+
private:
uint32_t next_id_ = 1;
InstructionList capabilities_;
@@ -169,6 +190,7 @@
std::vector<Function> functions_;
Hashset<uint32_t, 8> capability_set_;
Hashset<std::string, 8> extension_set_;
+ std::vector<uint32_t> code_;
};
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/constant_test.cc b/src/tint/lang/spirv/writer/constant_test.cc
index a2f3980..9f24475 100644
--- a/src/tint/lang/spirv/writer/constant_test.cc
+++ b/src/tint/lang/spirv/writer/constant_test.cc
@@ -35,154 +35,187 @@
using namespace tint::core::fluent_types; // NOLINT
TEST_F(SpirvWriterTest, Constant_Bool) {
- writer_.Constant(b.Constant(true));
- writer_.Constant(b.Constant(false));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", true);
+ b.Var<private_, read_write>("v", false);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%true = OpConstantTrue %bool");
EXPECT_INST("%false = OpConstantFalse %bool");
}
TEST_F(SpirvWriterTest, Constant_I32) {
- writer_.Constant(b.Constant(i32(42)));
- writer_.Constant(b.Constant(i32(-1)));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", i32(42));
+ b.Var<private_, read_write>("v", i32(-1));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%int_42 = OpConstant %int 42");
EXPECT_INST("%int_n1 = OpConstant %int -1");
}
TEST_F(SpirvWriterTest, Constant_U32) {
- writer_.Constant(b.Constant(u32(42)));
- writer_.Constant(b.Constant(u32(4000000000)));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", u32(42));
+ b.Var<private_, read_write>("v", u32(4000000000));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%uint_42 = OpConstant %uint 42");
EXPECT_INST("%uint_4000000000 = OpConstant %uint 4000000000");
}
TEST_F(SpirvWriterTest, Constant_F32) {
- writer_.Constant(b.Constant(f32(42)));
- writer_.Constant(b.Constant(f32(-1)));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", f32(42));
+ b.Var<private_, read_write>("v", f32(-1));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%float_42 = OpConstant %float 42");
EXPECT_INST("%float_n1 = OpConstant %float -1");
}
TEST_F(SpirvWriterTest, Constant_F16) {
- writer_.Constant(b.Constant(f16(42)));
- writer_.Constant(b.Constant(f16(-1)));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", f16(42));
+ b.Var<private_, read_write>("v", f16(-1));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%half_0x1_5p_5 = OpConstant %half 0x1.5p+5");
EXPECT_INST("%half_n0x1p_0 = OpConstant %half -0x1p+0");
}
TEST_F(SpirvWriterTest, Constant_Vec4Bool) {
- writer_.Constant(b.Composite(ty.vec4<bool>(), true, false, false, true));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", b.Composite(ty.vec4<bool>(), true, false, false, true));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %v4bool %true %false %false %true");
+ EXPECT_INST(" = OpConstantComposite %v4bool %true %false %false %true");
}
TEST_F(SpirvWriterTest, Constant_Vec2i) {
- writer_.Constant(b.Composite(ty.vec2<i32>(), 42_i, -1_i));
+ b.Append(b.ir.root_block,
+ [&] { b.Var<private_, read_write>("v", b.Composite(ty.vec2<i32>(), 42_i, -1_i)); });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %v2int %int_42 %int_n1");
+ EXPECT_INST(" = OpConstantComposite %v2int %int_42 %int_n1");
}
TEST_F(SpirvWriterTest, Constant_Vec3u) {
- writer_.Constant(b.Composite(ty.vec3<u32>(), 42_u, 0_u, 4000000000_u));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", b.Composite(ty.vec3<u32>(), 42_u, 0_u, 4000000000_u));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %v3uint %uint_42 %uint_0 %uint_4000000000");
+ EXPECT_INST(" = OpConstantComposite %v3uint %uint_42 %uint_0 %uint_4000000000");
}
TEST_F(SpirvWriterTest, Constant_Vec4f) {
- writer_.Constant(b.Composite(ty.vec4<f32>(), 42_f, 0_f, 0.25_f, -1_f));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", b.Composite(ty.vec4<f32>(), 42_f, 0_f, 0.25_f, -1_f));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %v4float %float_42 %float_0 %float_0_25 %float_n1");
+ EXPECT_INST(" = OpConstantComposite %v4float %float_42 %float_0 %float_0_25 %float_n1");
}
TEST_F(SpirvWriterTest, Constant_Vec2h) {
- writer_.Constant(b.Composite(ty.vec2<f16>(), 42_h, 0.25_h));
+ b.Append(b.ir.root_block,
+ [&] { b.Var<private_, read_write>("v", b.Composite(ty.vec2<f16>(), 42_h, 0.25_h)); });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %v2half %half_0x1_5p_5 %half_0x1pn2");
+ EXPECT_INST(" = OpConstantComposite %v2half %half_0x1_5p_5 %half_0x1pn2");
}
TEST_F(SpirvWriterTest, Constant_Mat2x3f) {
- writer_.Constant(b.Composite(ty.mat2x3<f32>(), //
- b.Composite(ty.vec3<f32>(), 42_f, -1_f, 0.25_f),
- b.Composite(ty.vec3<f32>(), -42_f, 0_f, -0.25_f)));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v",
+ b.Composite(ty.mat2x3<f32>(), //
+ b.Composite(ty.vec3<f32>(), 42_f, -1_f, 0.25_f),
+ b.Composite(ty.vec3<f32>(), -42_f, 0_f, -0.25_f)));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(R"(
%float_42 = OpConstant %float 42
%float_n1 = OpConstant %float -1
%float_0_25 = OpConstant %float 0.25
- %5 = OpConstantComposite %v3float %float_42 %float_n1 %float_0_25
+ %7 = OpConstantComposite %v3float %float_42 %float_n1 %float_0_25
%float_n42 = OpConstant %float -42
%float_0 = OpConstant %float 0
%float_n0_25 = OpConstant %float -0.25
- %9 = OpConstantComposite %v3float %float_n42 %float_0 %float_n0_25
- %1 = OpConstantComposite %mat2v3float %5 %9
+ %11 = OpConstantComposite %v3float %float_n42 %float_0 %float_n0_25
+ %6 = OpConstantComposite %mat2v3float %7 %11
)");
}
TEST_F(SpirvWriterTest, Constant_Mat4x2h) {
- writer_.Constant(b.Composite(ty.mat4x2<f16>(), //
- b.Composite(ty.vec2<f16>(), 42_h, -1_h), //
- b.Composite(ty.vec2<f16>(), 0_h, 0.25_h), //
- b.Composite(ty.vec2<f16>(), -42_h, 1_h), //
- b.Composite(ty.vec2<f16>(), 0.5_h, f16(-0))));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", b.Composite(ty.mat4x2<f16>(), //
+ b.Composite(ty.vec2<f16>(), 42_h, -1_h), //
+ b.Composite(ty.vec2<f16>(), 0_h, 0.25_h), //
+ b.Composite(ty.vec2<f16>(), -42_h, 1_h), //
+ b.Composite(ty.vec2<f16>(), 0.5_h, f16(-0))));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(R"(
%half_0x1_5p_5 = OpConstant %half 0x1.5p+5
%half_n0x1p_0 = OpConstant %half -0x1p+0
- %5 = OpConstantComposite %v2half %half_0x1_5p_5 %half_n0x1p_0
+ %7 = OpConstantComposite %v2half %half_0x1_5p_5 %half_n0x1p_0
%half_0x0p_0 = OpConstant %half 0x0p+0
%half_0x1pn2 = OpConstant %half 0x1p-2
- %8 = OpConstantComposite %v2half %half_0x0p_0 %half_0x1pn2
+ %10 = OpConstantComposite %v2half %half_0x0p_0 %half_0x1pn2
%half_n0x1_5p_5 = OpConstant %half -0x1.5p+5
%half_0x1p_0 = OpConstant %half 0x1p+0
- %11 = OpConstantComposite %v2half %half_n0x1_5p_5 %half_0x1p_0
+ %13 = OpConstantComposite %v2half %half_n0x1_5p_5 %half_0x1p_0
%half_0x1pn1 = OpConstant %half 0x1p-1
- %14 = OpConstantComposite %v2half %half_0x1pn1 %half_0x0p_0
- %1 = OpConstantComposite %mat4v2half %5 %8 %11 %14
+ %16 = OpConstantComposite %v2half %half_0x1pn1 %half_0x0p_0
+ %6 = OpConstantComposite %mat4v2half %7 %10 %13 %16
)");
}
TEST_F(SpirvWriterTest, Constant_Array_I32) {
- writer_.Constant(b.Composite(ty.array<i32, 4>(), 1_i, 2_i, 3_i, 4_i));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", b.Composite(ty.array<i32, 4>(), 1_i, 2_i, 3_i, 4_i));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4");
+ EXPECT_INST(" = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4");
}
TEST_F(SpirvWriterTest, Constant_Array_Array_I32) {
- auto* inner = b.Composite(ty.array<i32, 4>(), 1_i, 2_i, 3_i, 4_i);
- writer_.Constant(b.Composite(ty.array(ty.array<i32, 4>(), 4), inner, inner, inner, inner));
+ b.Append(b.ir.root_block, [&] {
+ auto* inner = b.Composite(ty.array<i32, 4>(), 1_i, 2_i, 3_i, 4_i);
+ b.Var<private_, read_write>(
+ "v", b.Composite(ty.array(ty.array<i32, 4>(), 4), inner, inner, inner, inner));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(R"(
- %7 = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4
- %1 = OpConstantComposite %_arr__arr_int_uint_4_uint_4 %7 %7 %7 %7
+ %9 = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4
+ %8 = OpConstantComposite %_arr__arr_int_uint_4_uint_4 %9 %9 %9 %9
)");
}
TEST_F(SpirvWriterTest, Constant_Array_LargeAllZero) {
- writer_.Constant(b.Zero(ty.array<i32, 65535>()));
+ b.Append(b.ir.root_block,
+ [&] { b.Var<private_, read_write>("v", b.Zero(ty.array<i32, 65535>())); });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantNull %_arr_int_uint_65535");
+ EXPECT_INST(" = OpConstantNull %_arr_int_uint_65535");
}
TEST_F(SpirvWriterTest, Constant_Struct) {
- auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
- {mod.symbols.New("a"), ty.i32()},
- {mod.symbols.New("b"), ty.u32()},
- {mod.symbols.New("c"), ty.f32()},
- });
- writer_.Constant(b.Composite(str_ty, 1_i, 2_u, 3_f));
+ b.Append(b.ir.root_block, [&] {
+ auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.u32()},
+ {mod.symbols.New("c"), ty.f32()},
+ });
+ b.Var<private_, read_write>("v", b.Composite(str_ty, 1_i, 2_u, 3_f));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpConstantComposite %MyStruct %int_1 %uint_2 %float_3");
+ EXPECT_INST(" = OpConstantComposite %MyStruct %int_1 %uint_2 %float_3");
}
// Test that we do not emit the same constant more than once.
TEST_F(SpirvWriterTest, Constant_Deduplicate) {
- writer_.Constant(b.Constant(42_i));
- writer_.Constant(b.Constant(42_i));
- writer_.Constant(b.Constant(42_i));
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, read_write>("v", 42_i);
+ b.Var<private_, read_write>("v", 42_i);
+ b.Var<private_, read_write>("v", 42_i);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%int_42 = OpConstant %int 42");
}
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 0bd3b7a..a7c86b8 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -31,8 +31,12 @@
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
+
+#include "src/tint/lang/core/address_space.h"
+#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/core/constant/scalar.h"
#include "src/tint/lang/core/constant/splat.h"
+#include "src/tint/lang/core/constant/value.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
@@ -40,6 +44,8 @@
#include "src/tint/lang/core/ir/block.h"
#include "src/tint/lang/core/ir/block_param.h"
#include "src/tint/lang/core/ir/break_if.h"
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
@@ -66,6 +72,7 @@
#include "src/tint/lang/core/ir/user_call.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/ir/var.h"
+#include "src/tint/lang/core/texel_format.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/atomic.h"
#include "src/tint/lang/core/type/bool.h"
@@ -86,13 +93,21 @@
#include "src/tint/lang/core/type/u32.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/core/type/void.h"
+#include "src/tint/lang/spirv/ir/builtin_call.h"
#include "src/tint/lang/spirv/ir/literal_operand.h"
#include "src/tint/lang/spirv/type/sampled_image.h"
#include "src/tint/lang/spirv/writer/ast_printer/ast_printer.h"
+#include "src/tint/lang/spirv/writer/common/binary_writer.h"
+#include "src/tint/lang/spirv/writer/common/function.h"
#include "src/tint/lang/spirv/writer/common/module.h"
#include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
+#include "src/tint/utils/containers/hashmap.h"
+#include "src/tint/utils/containers/vector.h"
+#include "src/tint/utils/diagnostic/diagnostic.h"
#include "src/tint/utils/macros/scoped_assignment.h"
+#include "src/tint/utils/result/result.h"
#include "src/tint/utils/rtti/switch.h"
+#include "src/tint/utils/symbol/symbol.h"
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
@@ -162,1871 +177,2090 @@
[&](Default) { return ty; });
}
-} // namespace
+/// PIMPL class for SPIR-V writer
+class Printer {
+ public:
+ /// Constructor
+ /// @param module the Tint IR module to generate
+ /// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
+ /// storage class with OpConstantNull
+ Printer(core::ir::Module& module, bool zero_init_workgroup_memory)
+ : ir_(module), b_(module), zero_init_workgroup_memory_(zero_init_workgroup_memory) {}
-Printer::Printer(core::ir::Module& module, bool zero_init_workgroup_mem)
- : ir_(module), b_(module), zero_init_workgroup_memory_(zero_init_workgroup_mem) {}
+ /// @returns the generated SPIR-V code on success, or failure
+ Result<std::vector<uint32_t>> Code() {
+ if (auto res = Generate(); !res) {
+ return res.Failure();
+ }
-Result<std::vector<uint32_t>> Printer::Generate() {
- auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "SPIR-V writer");
- if (!valid) {
- return valid.Failure();
+ // Serialize the module into binary SPIR-V.
+ BinaryWriter writer;
+ writer.WriteHeader(module_.IdBound(), kWriterVersion);
+ writer.WriteModule(module_);
+ return std::move(writer.Result());
}
- module_.PushCapability(SpvCapabilityShader);
- module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
- U32Operand(SpvMemoryModelGLSL450)});
+ /// @returns the generated SPIR-V module on success, or failure
+ Result<writer::Module> Module() {
+ if (auto res = Generate(); !res) {
+ return res.Failure();
+ }
- // Emit module-scope declarations.
- EmitRootBlock(ir_.root_block);
-
- // Emit functions.
- for (auto* func : ir_.functions) {
- EmitFunction(func);
+ // Serialize the module into binary SPIR-V.
+ BinaryWriter writer;
+ writer.WriteHeader(module_.IdBound(), kWriterVersion);
+ writer.WriteModule(module_);
+ module_.Code() = std::move(writer.Result());
+ return module_;
}
- // Serialize the module into binary SPIR-V.
- BinaryWriter writer;
- writer.WriteHeader(module_.IdBound(), kWriterVersion);
- writer.WriteModule(module_);
- return std::move(writer.Result());
-}
+ private:
+ core::ir::Module& ir_;
+ core::ir::Builder b_;
+ writer::Module module_;
+ BinaryWriter writer_;
-uint32_t Printer::Builtin(core::BuiltinValue builtin, core::AddressSpace addrspace) {
- switch (builtin) {
- case core::BuiltinValue::kPointSize:
- return SpvBuiltInPointSize;
- case core::BuiltinValue::kFragDepth:
- return SpvBuiltInFragDepth;
- case core::BuiltinValue::kFrontFacing:
- return SpvBuiltInFrontFacing;
- case core::BuiltinValue::kGlobalInvocationId:
- return SpvBuiltInGlobalInvocationId;
- case core::BuiltinValue::kInstanceIndex:
- return SpvBuiltInInstanceIndex;
- case core::BuiltinValue::kLocalInvocationId:
- return SpvBuiltInLocalInvocationId;
- case core::BuiltinValue::kLocalInvocationIndex:
- return SpvBuiltInLocalInvocationIndex;
- case core::BuiltinValue::kNumWorkgroups:
- return SpvBuiltInNumWorkgroups;
- case core::BuiltinValue::kPosition:
- if (addrspace == core::AddressSpace::kOut) {
- // Vertex output.
- return SpvBuiltInPosition;
- } else {
- // Fragment input.
- return SpvBuiltInFragCoord;
+ /// A function type used for an OpTypeFunction declaration.
+ struct FunctionType {
+ uint32_t return_type_id;
+ Vector<uint32_t, 4> param_type_ids;
+
+ /// Hasher provides a hash function for the FunctionType.
+ struct Hasher {
+ /// @param ft the FunctionType to create a hash for
+ /// @return the hash value
+ inline std::size_t operator()(const FunctionType& ft) const {
+ size_t hash = Hash(ft.return_type_id);
+ for (auto& p : ft.param_type_ids) {
+ hash = HashCombine(hash, p);
+ }
+ return hash;
}
- case core::BuiltinValue::kSampleIndex:
- module_.PushCapability(SpvCapabilitySampleRateShading);
- return SpvBuiltInSampleId;
- case core::BuiltinValue::kSampleMask:
- return SpvBuiltInSampleMask;
- case core::BuiltinValue::kSubgroupInvocationId:
- module_.PushCapability(SpvCapabilityGroupNonUniform);
- return SpvBuiltInSubgroupLocalInvocationId;
- case core::BuiltinValue::kSubgroupSize:
- module_.PushCapability(SpvCapabilityGroupNonUniform);
- return SpvBuiltInSubgroupSize;
- case core::BuiltinValue::kVertexIndex:
- return SpvBuiltInVertexIndex;
- case core::BuiltinValue::kWorkgroupId:
- return SpvBuiltInWorkgroupId;
- case core::BuiltinValue::kUndefined:
- return SpvBuiltInMax;
- }
- return SpvBuiltInMax;
-}
+ };
-uint32_t Printer::Constant(core::ir::Constant* constant) {
- // If it is a literal operand, just return the value.
- if (auto* literal = constant->As<spirv::ir::LiteralOperand>()) {
- return literal->Value()->ValueAs<uint32_t>();
- }
-
- auto id = Constant(constant->Value());
-
- // Set the name for the SPIR-V result ID if provided in the module.
- if (auto name = ir_.NameOf(constant)) {
- module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
- }
-
- return id;
-}
-
-uint32_t Printer::Constant(const core::constant::Value* constant) {
- return constants_.GetOrCreate(constant, [&] {
- auto* ty = constant->Type();
-
- // Use OpConstantNull for zero-valued composite constants.
- if (!ty->Is<core::type::Scalar>() && constant->AllZero()) {
- return ConstantNull(ty);
+ /// Equality operator for FunctionType.
+ bool operator==(const FunctionType& other) const {
+ return (param_type_ids == other.param_type_ids) &&
+ (return_type_id == other.return_type_id);
}
-
- auto id = module_.NextId();
- Switch(
- ty, //
- [&](const core::type::Bool*) {
- module_.PushType(
- constant->ValueAs<bool>() ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse,
- {Type(ty), id});
- },
- [&](const core::type::I32*) {
- module_.PushType(spv::Op::OpConstant, {Type(ty), id, constant->ValueAs<u32>()});
- },
- [&](const core::type::U32*) {
- module_.PushType(spv::Op::OpConstant,
- {Type(ty), id, U32Operand(constant->ValueAs<i32>())});
- },
- [&](const core::type::F32*) {
- module_.PushType(spv::Op::OpConstant, {Type(ty), id, constant->ValueAs<f32>()});
- },
- [&](const core::type::F16*) {
- module_.PushType(
- spv::Op::OpConstant,
- {Type(ty), id, U32Operand(constant->ValueAs<f16>().BitsRepresentation())});
- },
- [&](const core::type::Vector* vec) {
- OperandList operands = {Type(ty), id};
- for (uint32_t i = 0; i < vec->Width(); i++) {
- operands.push_back(Constant(constant->Index(i)));
- }
- module_.PushType(spv::Op::OpConstantComposite, operands);
- },
- [&](const core::type::Matrix* mat) {
- OperandList operands = {Type(ty), id};
- for (uint32_t i = 0; i < mat->columns(); i++) {
- operands.push_back(Constant(constant->Index(i)));
- }
- module_.PushType(spv::Op::OpConstantComposite, operands);
- },
- [&](const core::type::Array* arr) {
- TINT_ASSERT(arr->ConstantCount());
- OperandList operands = {Type(ty), id};
- for (uint32_t i = 0; i < arr->ConstantCount(); i++) {
- operands.push_back(Constant(constant->Index(i)));
- }
- module_.PushType(spv::Op::OpConstantComposite, operands);
- },
- [&](const core::type::Struct* str) {
- OperandList operands = {Type(ty), id};
- for (uint32_t i = 0; i < str->Members().Length(); i++) {
- operands.push_back(Constant(constant->Index(i)));
- }
- module_.PushType(spv::Op::OpConstantComposite, operands);
- },
- [&](Default) { TINT_ICE() << "unhandled constant type: " << ty->FriendlyName(); });
- return id;
- });
-}
-
-uint32_t Printer::ConstantNull(const core::type::Type* type) {
- return constant_nulls_.GetOrCreate(type, [&] {
- auto id = module_.NextId();
- module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
- return id;
- });
-}
-
-uint32_t Printer::Undef(const core::type::Type* type) {
- return undef_values_.GetOrCreate(type, [&] {
- auto id = module_.NextId();
- module_.PushType(spv::Op::OpUndef, {Type(type), id});
- return id;
- });
-}
-
-uint32_t Printer::Type(const core::type::Type* ty) {
- ty = DedupType(ty, ir_.Types());
- return types_.GetOrCreate(ty, [&] {
- auto id = module_.NextId();
- Switch(
- ty, //
- [&](const core::type::Void*) { module_.PushType(spv::Op::OpTypeVoid, {id}); },
- [&](const core::type::Bool*) { module_.PushType(spv::Op::OpTypeBool, {id}); },
- [&](const core::type::I32*) {
- module_.PushType(spv::Op::OpTypeInt, {id, 32u, 1u});
- },
- [&](const core::type::U32*) {
- module_.PushType(spv::Op::OpTypeInt, {id, 32u, 0u});
- },
- [&](const core::type::F32*) {
- module_.PushType(spv::Op::OpTypeFloat, {id, 32u});
- },
- [&](const core::type::F16*) {
- module_.PushCapability(SpvCapabilityFloat16);
- module_.PushCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
- module_.PushCapability(SpvCapabilityStorageBuffer16BitAccess);
- module_.PushCapability(SpvCapabilityStorageInputOutput16);
- module_.PushType(spv::Op::OpTypeFloat, {id, 16u});
- },
- [&](const core::type::Vector* vec) {
- module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()});
- },
- [&](const core::type::Matrix* mat) {
- module_.PushType(spv::Op::OpTypeMatrix,
- {id, Type(mat->ColumnType()), mat->columns()});
- },
- [&](const core::type::Array* arr) {
- if (arr->ConstantCount()) {
- auto* count = b_.ConstantValue(u32(arr->ConstantCount().value()));
- module_.PushType(spv::Op::OpTypeArray,
- {id, Type(arr->ElemType()), Constant(count)});
- } else {
- TINT_ASSERT(arr->Count()->Is<core::type::RuntimeArrayCount>());
- module_.PushType(spv::Op::OpTypeRuntimeArray, {id, Type(arr->ElemType())});
- }
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationArrayStride), arr->Stride()});
- },
- [&](const core::type::Pointer* ptr) {
- module_.PushType(
- spv::Op::OpTypePointer,
- {id, U32Operand(StorageClass(ptr->AddressSpace())), Type(ptr->StoreType())});
- },
- [&](const core::type::Struct* str) { EmitStructType(id, str); },
- [&](const core::type::Texture* tex) { EmitTextureType(id, tex); },
- [&](const core::type::Sampler*) { module_.PushType(spv::Op::OpTypeSampler, {id}); },
- [&](const type::SampledImage* s) {
- module_.PushType(spv::Op::OpTypeSampledImage, {id, Type(s->Image())});
- },
- [&](Default) { TINT_ICE() << "unhandled type: " << ty->FriendlyName(); });
- return id;
- });
-}
-
-uint32_t Printer::Value(core::ir::Instruction* inst) {
- return Value(inst->Result());
-}
-
-uint32_t Printer::Value(core::ir::Value* value) {
- return Switch(
- value, //
- [&](core::ir::Constant* constant) { return Constant(constant); },
- [&](core::ir::Value*) {
- return values_.GetOrCreate(value, [&] { return module_.NextId(); });
- });
-}
-
-uint32_t Printer::Label(core::ir::Block* block) {
- return block_labels_.GetOrCreate(block, [&] { return module_.NextId(); });
-}
-
-void Printer::EmitStructType(uint32_t id, const core::type::Struct* str) {
- // Helper to return `type` or a potentially nested array element type within `type` as a matrix
- // type, or nullptr if no such matrix type is present.
- auto get_nested_matrix_type = [&](const core::type::Type* type) {
- while (auto* arr = type->As<core::type::Array>()) {
- type = arr->ElemType();
- }
- return type->As<core::type::Matrix>();
};
- OperandList operands = {id};
- for (auto* member : str->Members()) {
- operands.push_back(Type(member->Type()));
+ /// The map of types to their result IDs.
+ Hashmap<const core::type::Type*, uint32_t, 8> types_;
- // Generate struct member offset decoration.
- module_.PushAnnot(
- spv::Op::OpMemberDecorate,
- {operands[0], member->Index(), U32Operand(SpvDecorationOffset), member->Offset()});
+ /// The map of function types to their result IDs.
+ Hashmap<FunctionType, uint32_t, 8, FunctionType::Hasher> function_types_;
- // Emit matrix layout decorations if necessary.
- if (auto* matrix_type = get_nested_matrix_type(member->Type())) {
- const uint32_t effective_row_count = (matrix_type->rows() == 2) ? 2 : 4;
- module_.PushAnnot(spv::Op::OpMemberDecorate,
- {id, member->Index(), U32Operand(SpvDecorationColMajor)});
- module_.PushAnnot(spv::Op::OpMemberDecorate,
- {id, member->Index(), U32Operand(SpvDecorationMatrixStride),
- Operand(effective_row_count * matrix_type->type()->Size())});
+ /// The map of constants to their result IDs.
+ Hashmap<const core::constant::Value*, uint32_t, 16> constants_;
+
+ /// The map of types to the result IDs of their OpConstantNull instructions.
+ Hashmap<const core::type::Type*, uint32_t, 4> constant_nulls_;
+
+ /// The map of types to the result IDs of their OpUndef instructions.
+ Hashmap<const core::type::Type*, uint32_t, 4> undef_values_;
+
+ /// The map of non-constant values to their result IDs.
+ Hashmap<core::ir::Value*, uint32_t, 8> values_;
+
+ /// The map of blocks to the IDs of their label instructions.
+ Hashmap<core::ir::Block*, uint32_t, 8> block_labels_;
+
+ /// The map of control instructions to the IDs of the label of their SPIR-V merge blocks.
+ Hashmap<core::ir::ControlInstruction*, uint32_t, 8> merge_block_labels_;
+
+ /// The map of extended instruction set names to their result IDs.
+ Hashmap<std::string_view, uint32_t, 2> imports_;
+
+ /// The current function that is being emitted.
+ Function current_function_;
+
+ /// The merge block for the current if statement
+ uint32_t if_merge_label_ = 0;
+
+ /// The header block for the current loop statement
+ uint32_t loop_header_label_ = 0;
+
+ /// The merge block for the current loop statement
+ uint32_t loop_merge_label_ = 0;
+
+ /// The merge block for the current switch statement
+ uint32_t switch_merge_label_ = 0;
+
+ bool zero_init_workgroup_memory_ = false;
+
+ /// Builds the SPIR-V from the IR
+ Result<SuccessType> Generate() {
+ auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "SPIR-V writer");
+ if (!valid) {
+ return valid.Failure();
}
- if (member->Name().IsValid()) {
- module_.PushDebug(spv::Op::OpMemberName,
- {operands[0], member->Index(), Operand(member->Name().Name())});
- }
- }
- module_.PushType(spv::Op::OpTypeStruct, std::move(operands));
+ module_.PushCapability(SpvCapabilityShader);
+ module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
+ U32Operand(SpvMemoryModelGLSL450)});
- // Add a Block decoration if necessary.
- if (str->StructFlags().Contains(core::type::StructFlag::kBlock)) {
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBlock)});
+ // Emit module-scope declarations.
+ EmitRootBlock(ir_.root_block);
+
+ // Emit functions.
+ for (auto* func : ir_.functions) {
+ EmitFunction(func);
+ }
+
+ return Success;
}
- if (str->Name().IsValid()) {
- module_.PushDebug(spv::Op::OpName, {operands[0], Operand(str->Name().Name())});
- }
-}
-
-void Printer::EmitTextureType(uint32_t id, const core::type::Texture* texture) {
- uint32_t sampled_type = Switch(
- texture, //
- [&](const core::type::SampledTexture* t) { return Type(t->type()); },
- [&](const core::type::MultisampledTexture* t) { return Type(t->type()); },
- [&](const core::type::StorageTexture* t) { return Type(t->type()); },
- [&](Default) {
- TINT_ICE() << "unhandled texture type: " << texture->TypeInfo().name;
- return 0u;
- });
-
- uint32_t dim = SpvDimMax;
- uint32_t array = 0u;
- switch (texture->dim()) {
- case core::type::TextureDimension::kNone: {
- break;
+ /// Convert a builtin to the corresponding SPIR-V enum value, taking into account the target
+ /// address space. Adds any capabilities needed for the builtin.
+ /// @param builtin the builtin to convert
+ /// @param addrspace the address space the builtin is being used in
+ /// @returns the enum value of the corresponding SPIR-V builtin
+ uint32_t Builtin(core::BuiltinValue builtin, core::AddressSpace addrspace) {
+ switch (builtin) {
+ case core::BuiltinValue::kPointSize:
+ return SpvBuiltInPointSize;
+ case core::BuiltinValue::kFragDepth:
+ return SpvBuiltInFragDepth;
+ case core::BuiltinValue::kFrontFacing:
+ return SpvBuiltInFrontFacing;
+ case core::BuiltinValue::kGlobalInvocationId:
+ return SpvBuiltInGlobalInvocationId;
+ case core::BuiltinValue::kInstanceIndex:
+ return SpvBuiltInInstanceIndex;
+ case core::BuiltinValue::kLocalInvocationId:
+ return SpvBuiltInLocalInvocationId;
+ case core::BuiltinValue::kLocalInvocationIndex:
+ return SpvBuiltInLocalInvocationIndex;
+ case core::BuiltinValue::kNumWorkgroups:
+ return SpvBuiltInNumWorkgroups;
+ case core::BuiltinValue::kPosition:
+ if (addrspace == core::AddressSpace::kOut) {
+ // Vertex output.
+ return SpvBuiltInPosition;
+ } else {
+ // Fragment input.
+ return SpvBuiltInFragCoord;
+ }
+ case core::BuiltinValue::kSampleIndex:
+ module_.PushCapability(SpvCapabilitySampleRateShading);
+ return SpvBuiltInSampleId;
+ case core::BuiltinValue::kSampleMask:
+ return SpvBuiltInSampleMask;
+ case core::BuiltinValue::kSubgroupInvocationId:
+ module_.PushCapability(SpvCapabilityGroupNonUniform);
+ return SpvBuiltInSubgroupLocalInvocationId;
+ case core::BuiltinValue::kSubgroupSize:
+ module_.PushCapability(SpvCapabilityGroupNonUniform);
+ return SpvBuiltInSubgroupSize;
+ case core::BuiltinValue::kVertexIndex:
+ return SpvBuiltInVertexIndex;
+ case core::BuiltinValue::kWorkgroupId:
+ return SpvBuiltInWorkgroupId;
+ case core::BuiltinValue::kUndefined:
+ return SpvBuiltInMax;
}
- case core::type::TextureDimension::k1d: {
- dim = SpvDim1D;
- if (texture->Is<core::type::SampledTexture>()) {
- module_.PushCapability(SpvCapabilitySampled1D);
- } else if (texture->Is<core::type::StorageTexture>()) {
- module_.PushCapability(SpvCapabilityImage1D);
- }
- break;
- }
- case core::type::TextureDimension::k2d: {
- dim = SpvDim2D;
- break;
- }
- case core::type::TextureDimension::k2dArray: {
- dim = SpvDim2D;
- array = 1u;
- break;
- }
- case core::type::TextureDimension::k3d: {
- dim = SpvDim3D;
- break;
- }
- case core::type::TextureDimension::kCube: {
- dim = SpvDimCube;
- break;
- }
- case core::type::TextureDimension::kCubeArray: {
- dim = SpvDimCube;
- array = 1u;
- if (texture->Is<core::type::SampledTexture>()) {
- module_.PushCapability(SpvCapabilitySampledCubeArray);
- }
- break;
- }
+ return SpvBuiltInMax;
}
- // The Vulkan spec says: The "Depth" operand of OpTypeImage is ignored.
- // In SPIRV, 0 means not depth, 1 means depth, and 2 means unknown.
- // Using anything other than 0 is problematic on various Vulkan drivers.
- uint32_t depth = 0u;
-
- uint32_t ms = 0u;
- if (texture->Is<core::type::MultisampledTexture>()) {
- ms = 1u;
- }
-
- uint32_t sampled = 2u;
- if (texture->IsAnyOf<core::type::MultisampledTexture, core::type::SampledTexture>()) {
- sampled = 1u;
- }
-
- uint32_t format = SpvImageFormat_::SpvImageFormatUnknown;
- if (auto* st = texture->As<core::type::StorageTexture>()) {
- format = TexelFormat(st->texel_format());
- }
-
- module_.PushType(spv::Op::OpTypeImage,
- {id, sampled_type, dim, depth, array, ms, sampled, format});
-}
-
-void Printer::EmitFunction(core::ir::Function* func) {
- auto id = Value(func);
-
- // Emit the function name.
- module_.PushDebug(spv::Op::OpName, {id, Operand(ir_.NameOf(func).Name())});
-
- // Emit OpEntryPoint and OpExecutionMode declarations if needed.
- if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
- EmitEntryPoint(func, id);
- }
-
- // Get the ID for the return type.
- auto return_type_id = Type(func->ReturnType());
-
- FunctionType function_type{return_type_id, {}};
- InstructionList params;
-
- // Generate function parameter declarations and add their type IDs to the function signature.
- for (auto* param : func->Params()) {
- auto param_type_id = Type(param->Type());
- auto param_id = Value(param);
- params.push_back(Instruction(spv::Op::OpFunctionParameter, {param_type_id, param_id}));
- function_type.param_type_ids.Push(param_type_id);
- if (auto name = ir_.NameOf(param)) {
- module_.PushDebug(spv::Op::OpName, {param_id, Operand(name.Name())});
- }
- }
-
- // Get the ID for the function type (creating it if needed).
- auto function_type_id = function_types_.GetOrCreate(function_type, [&] {
- auto func_ty_id = module_.NextId();
- OperandList operands = {func_ty_id, return_type_id};
- operands.insert(operands.end(), function_type.param_type_ids.begin(),
- function_type.param_type_ids.end());
- module_.PushType(spv::Op::OpTypeFunction, operands);
- return func_ty_id;
- });
-
- // Declare the function.
- auto decl =
- Instruction{spv::Op::OpFunction,
- {return_type_id, id, U32Operand(SpvFunctionControlMaskNone), function_type_id}};
-
- // Create a function that we will add instructions to.
- auto entry_block = module_.NextId();
- current_function_ = Function(decl, entry_block, std::move(params));
- TINT_DEFER(current_function_ = Function());
-
- // Emit the body of the function.
- EmitBlock(func->Block());
-
- // Add the function to the module.
- module_.PushFunction(current_function_);
-}
-
-void Printer::EmitEntryPoint(core::ir::Function* func, uint32_t id) {
- SpvExecutionModel stage = SpvExecutionModelMax;
- switch (func->Stage()) {
- case core::ir::Function::PipelineStage::kCompute: {
- stage = SpvExecutionModelGLCompute;
- module_.PushExecutionMode(
- spv::Op::OpExecutionMode,
- {id, U32Operand(SpvExecutionModeLocalSize), func->WorkgroupSize()->at(0),
- func->WorkgroupSize()->at(1), func->WorkgroupSize()->at(2)});
- break;
- }
- case core::ir::Function::PipelineStage::kFragment: {
- stage = SpvExecutionModelFragment;
- module_.PushExecutionMode(spv::Op::OpExecutionMode,
- {id, U32Operand(SpvExecutionModeOriginUpperLeft)});
- break;
- }
- case core::ir::Function::PipelineStage::kVertex: {
- stage = SpvExecutionModelVertex;
- break;
- }
- case core::ir::Function::PipelineStage::kUndefined:
- TINT_ICE() << "undefined pipeline stage for entry point";
- return;
- }
-
- OperandList operands = {U32Operand(stage), id, ir_.NameOf(func).Name()};
-
- // Add the list of all referenced shader IO variables.
- for (auto* global : *ir_.root_block) {
- auto* var = global->As<core::ir::Var>();
- if (!var) {
- continue;
+ /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
+ /// @param constant the constant to get the ID for
+ /// @returns the result ID of the constant
+ uint32_t Constant(core::ir::Constant* constant) {
+ // If it is a literal operand, just return the value.
+ if (auto* literal = constant->As<spirv::ir::LiteralOperand>()) {
+ return literal->Value()->ValueAs<uint32_t>();
}
- auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
- if (!(ptr->AddressSpace() == core::AddressSpace::kIn ||
- ptr->AddressSpace() == core::AddressSpace::kOut)) {
- continue;
- }
-
- // Determine if this IO variable is used by the entry point.
- bool used = false;
- for (const auto& use : var->Result()->Usages()) {
- auto* block = use.instruction->Block();
- while (block->Parent()) {
- block = block->Parent()->Block();
- }
- if (block == func->Block()) {
- used = true;
- break;
- }
- }
- if (!used) {
- continue;
- }
- operands.push_back(Value(var));
-
- // Add the `DepthReplacing` execution mode if `frag_depth` is used.
- if (var->Attributes().builtin == core::BuiltinValue::kFragDepth) {
- module_.PushExecutionMode(spv::Op::OpExecutionMode,
- {id, U32Operand(SpvExecutionModeDepthReplacing)});
- }
- }
-
- module_.PushEntryPoint(spv::Op::OpEntryPoint, operands);
-}
-
-void Printer::EmitRootBlock(core::ir::Block* root_block) {
- for (auto* inst : *root_block) {
- Switch(
- inst, //
- [&](core::ir::Var* v) { return EmitVar(v); },
- [&](Default) {
- TINT_ICE() << "unimplemented root block instruction: " << inst->TypeInfo().name;
- });
- }
-}
-
-void Printer::EmitBlock(core::ir::Block* block) {
- // Emit the label.
- // Skip if this is the function's entry block, as it will be emitted by the function object.
- if (!current_function_.instructions().empty()) {
- current_function_.push_inst(spv::Op::OpLabel, {Label(block)});
- }
-
- // If there are no instructions in the block, it's a dead end, so we shouldn't be able to get
- // here to begin with.
- if (block->IsEmpty()) {
- current_function_.push_inst(spv::Op::OpUnreachable, {});
- return;
- }
-
- if (auto* mib = block->As<core::ir::MultiInBlock>()) {
- // Emit all OpPhi nodes for incoming branches to block.
- EmitIncomingPhis(mib);
- }
-
- // Emit the block's statements.
- EmitBlockInstructions(block);
-}
-
-void Printer::EmitIncomingPhis(core::ir::MultiInBlock* block) {
- // Emit Phi nodes for all the incoming block parameters
- for (size_t param_idx = 0; param_idx < block->Params().Length(); param_idx++) {
- auto* param = block->Params()[param_idx];
- OperandList ops{Type(param->Type()), Value(param)};
-
- for (auto* incoming : block->InboundSiblingBranches()) {
- auto* arg = incoming->Args()[param_idx];
- ops.push_back(Value(arg));
- ops.push_back(GetTerminatorBlockLabel(incoming));
- }
-
- current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
- }
-}
-
-void Printer::EmitBlockInstructions(core::ir::Block* block) {
- for (auto* inst : *block) {
- Switch(
- inst, //
- [&](core::ir::Access* a) { EmitAccess(a); }, //
- [&](core::ir::Binary* b) { EmitBinary(b); }, //
- [&](core::ir::Bitcast* b) { EmitBitcast(b); }, //
- [&](core::ir::CoreBuiltinCall* b) { EmitCoreBuiltinCall(b); }, //
- [&](spirv::ir::BuiltinCall* b) { EmitSpirvBuiltinCall(b); }, //
- [&](core::ir::Construct* c) { EmitConstruct(c); }, //
- [&](core::ir::Convert* c) { EmitConvert(c); }, //
- [&](core::ir::Load* l) { EmitLoad(l); }, //
- [&](core::ir::LoadVectorElement* l) { EmitLoadVectorElement(l); }, //
- [&](core::ir::Loop* l) { EmitLoop(l); }, //
- [&](core::ir::Switch* sw) { EmitSwitch(sw); }, //
- [&](core::ir::Swizzle* s) { EmitSwizzle(s); }, //
- [&](core::ir::Store* s) { EmitStore(s); }, //
- [&](core::ir::StoreVectorElement* s) { EmitStoreVectorElement(s); }, //
- [&](core::ir::UserCall* c) { EmitUserCall(c); }, //
- [&](core::ir::Unary* u) { EmitUnary(u); }, //
- [&](core::ir::Var* v) { EmitVar(v); }, //
- [&](core::ir::Let* l) { EmitLet(l); }, //
- [&](core::ir::If* i) { EmitIf(i); }, //
- [&](core::ir::Terminator* t) { EmitTerminator(t); }, //
- [&](Default) { TINT_ICE() << "unimplemented instruction: " << inst->TypeInfo().name; });
+ auto id = Constant(constant->Value());
// Set the name for the SPIR-V result ID if provided in the module.
- if (inst->Result() && !inst->Is<core::ir::Var>()) {
- if (auto name = ir_.NameOf(inst)) {
- module_.PushDebug(spv::Op::OpName, {Value(inst), Operand(name.Name())});
- }
- }
- }
-
- if (block->IsEmpty()) {
- // If the last emitted instruction is not a branch, then this should be unreachable.
- current_function_.push_inst(spv::Op::OpUnreachable, {});
- }
-}
-
-void Printer::EmitTerminator(core::ir::Terminator* t) {
- tint::Switch( //
- t, //
- [&](core::ir::Return*) {
- if (!t->Args().IsEmpty()) {
- TINT_ASSERT(t->Args().Length() == 1u);
- OperandList operands;
- operands.push_back(Value(t->Args()[0]));
- current_function_.push_inst(spv::Op::OpReturnValue, operands);
- } else {
- current_function_.push_inst(spv::Op::OpReturn, {});
- }
- return;
- },
- [&](core::ir::BreakIf* breakif) {
- current_function_.push_inst(spv::Op::OpBranchConditional,
- {
- Value(breakif->Condition()),
- loop_merge_label_,
- loop_header_label_,
- });
- },
- [&](core::ir::Continue* cont) {
- current_function_.push_inst(spv::Op::OpBranch, {Label(cont->Loop()->Continuing())});
- },
- [&](core::ir::ExitIf*) {
- current_function_.push_inst(spv::Op::OpBranch, {if_merge_label_});
- },
- [&](core::ir::ExitLoop*) {
- current_function_.push_inst(spv::Op::OpBranch, {loop_merge_label_});
- },
- [&](core::ir::ExitSwitch*) {
- current_function_.push_inst(spv::Op::OpBranch, {switch_merge_label_});
- },
- [&](core::ir::NextIteration*) {
- current_function_.push_inst(spv::Op::OpBranch, {loop_header_label_});
- },
- [&](core::ir::TerminateInvocation*) { current_function_.push_inst(spv::Op::OpKill, {}); },
- [&](core::ir::Unreachable*) { current_function_.push_inst(spv::Op::OpUnreachable, {}); },
-
- [&](Default) { TINT_ICE() << "unimplemented branch: " << t->TypeInfo().name; });
-}
-
-void Printer::EmitIf(core::ir::If* i) {
- auto* true_block = i->True();
- auto* false_block = i->False();
-
- // Generate labels for the blocks. We emit the true or false block if it:
- // 1. contains instructions other then the branch, or
- // 2. branches somewhere instead of exiting the loop (e.g. return or break), or
- // 3. the if returns a value
- // Otherwise we skip them and branch straight to the merge block.
- uint32_t merge_label = GetMergeLabel(i);
- TINT_SCOPED_ASSIGNMENT(if_merge_label_, merge_label);
-
- uint32_t true_label = merge_label;
- uint32_t false_label = merge_label;
- if (true_block->Length() > 1 || i->HasResults() ||
- (true_block->HasTerminator() && !true_block->Terminator()->Is<core::ir::ExitIf>())) {
- true_label = Label(true_block);
- }
- if (false_block->Length() > 1 || i->HasResults() ||
- (false_block->HasTerminator() && !false_block->Terminator()->Is<core::ir::ExitIf>())) {
- false_label = Label(false_block);
- }
-
- // Emit the OpSelectionMerge and OpBranchConditional instructions.
- current_function_.push_inst(spv::Op::OpSelectionMerge,
- {merge_label, U32Operand(SpvSelectionControlMaskNone)});
- current_function_.push_inst(spv::Op::OpBranchConditional,
- {Value(i->Condition()), true_label, false_label});
-
- // Emit the `true` and `false` blocks, if they're not being skipped.
- if (true_label != merge_label) {
- EmitBlock(true_block);
- }
- if (false_label != merge_label) {
- EmitBlock(false_block);
- }
-
- current_function_.push_inst(spv::Op::OpLabel, {merge_label});
-
- // Emit the OpPhis for the ExitIfs
- EmitExitPhis(i);
-}
-
-void Printer::EmitAccess(core::ir::Access* access) {
- auto* ty = access->Result()->Type();
-
- auto id = Value(access);
- OperandList operands = {Type(ty), id, Value(access->Object())};
-
- if (ty->Is<core::type::Pointer>()) {
- // Use OpAccessChain for accesses into pointer types.
- for (auto* idx : access->Indices()) {
- operands.push_back(Value(idx));
- }
- current_function_.push_inst(spv::Op::OpAccessChain, std::move(operands));
- return;
- }
-
- // For non-pointer types, we assume that the indices are constants and use OpCompositeExtract.
- // If we hit a non-constant index into a vector type, use OpVectorExtractDynamic for it.
- auto* source_ty = access->Object()->Type();
- for (auto* idx : access->Indices()) {
- if (auto* constant = idx->As<core::ir::Constant>()) {
- // Push the index to the chain and update the current type.
- auto i = constant->Value()->ValueAs<u32>();
- operands.push_back(i);
- source_ty = source_ty->Element(i);
- } else {
- // The VarForDynamicIndex transform ensures that only value types that are vectors
- // will be dynamically indexed, as we can use OpVectorExtractDynamic for this case.
- TINT_ASSERT(source_ty->Is<core::type::Vector>());
-
- // If this wasn't the first access in the chain then emit the chain so far as an
- // OpCompositeExtract, creating a new result ID for the resulting vector.
- auto vec_id = Value(access->Object());
- if (operands.size() > 3) {
- vec_id = module_.NextId();
- operands[0] = Type(source_ty);
- operands[1] = vec_id;
- current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
- }
-
- // Now emit the OpVectorExtractDynamic instruction.
- operands = {Type(ty), id, vec_id, Value(idx)};
- current_function_.push_inst(spv::Op::OpVectorExtractDynamic, std::move(operands));
- return;
- }
- }
- current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
-}
-
-void Printer::EmitBinary(core::ir::Binary* binary) {
- auto id = Value(binary);
- auto lhs = Value(binary->LHS());
- auto rhs = Value(binary->RHS());
- auto* ty = binary->Result()->Type();
- auto* lhs_ty = binary->LHS()->Type();
-
- // Determine the opcode.
- spv::Op op = spv::Op::Max;
- switch (binary->Op()) {
- case core::ir::BinaryOp::kAdd: {
- op = ty->is_integer_scalar_or_vector() ? spv::Op::OpIAdd : spv::Op::OpFAdd;
- break;
- }
- case core::ir::BinaryOp::kDivide: {
- if (ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSDiv;
- } else if (ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpUDiv;
- } else if (ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFDiv;
- }
- break;
- }
- case core::ir::BinaryOp::kMultiply: {
- if (ty->is_integer_scalar_or_vector()) {
- op = spv::Op::OpIMul;
- } else if (ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFMul;
- }
- break;
- }
- case core::ir::BinaryOp::kSubtract: {
- op = ty->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
- break;
- }
- case core::ir::BinaryOp::kModulo: {
- if (ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSRem;
- } else if (ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpUMod;
- } else if (ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFRem;
- }
- break;
+ if (auto name = ir_.NameOf(constant)) {
+ module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
}
- case core::ir::BinaryOp::kAnd: {
- if (ty->is_integer_scalar_or_vector()) {
- op = spv::Op::OpBitwiseAnd;
- } else if (ty->is_bool_scalar_or_vector()) {
- op = spv::Op::OpLogicalAnd;
- }
- break;
- }
- case core::ir::BinaryOp::kOr: {
- if (ty->is_integer_scalar_or_vector()) {
- op = spv::Op::OpBitwiseOr;
- } else if (ty->is_bool_scalar_or_vector()) {
- op = spv::Op::OpLogicalOr;
- }
- break;
- }
- case core::ir::BinaryOp::kXor: {
- op = spv::Op::OpBitwiseXor;
- break;
- }
-
- case core::ir::BinaryOp::kShiftLeft: {
- op = spv::Op::OpShiftLeftLogical;
- break;
- }
- case core::ir::BinaryOp::kShiftRight: {
- if (ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpShiftRightArithmetic;
- } else if (ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpShiftRightLogical;
- }
- break;
- }
-
- case core::ir::BinaryOp::kEqual: {
- if (lhs_ty->is_bool_scalar_or_vector()) {
- op = spv::Op::OpLogicalEqual;
- } else if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdEqual;
- } else if (lhs_ty->is_integer_scalar_or_vector()) {
- op = spv::Op::OpIEqual;
- }
- break;
- }
- case core::ir::BinaryOp::kNotEqual: {
- if (lhs_ty->is_bool_scalar_or_vector()) {
- op = spv::Op::OpLogicalNotEqual;
- } else if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdNotEqual;
- } else if (lhs_ty->is_integer_scalar_or_vector()) {
- op = spv::Op::OpINotEqual;
- }
- break;
- }
- case core::ir::BinaryOp::kGreaterThan: {
- if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdGreaterThan;
- } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSGreaterThan;
- } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpUGreaterThan;
- }
- break;
- }
- case core::ir::BinaryOp::kGreaterThanEqual: {
- if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdGreaterThanEqual;
- } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSGreaterThanEqual;
- } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpUGreaterThanEqual;
- }
- break;
- }
- case core::ir::BinaryOp::kLessThan: {
- if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdLessThan;
- } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSLessThan;
- } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpULessThan;
- }
- break;
- }
- case core::ir::BinaryOp::kLessThanEqual: {
- if (lhs_ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFOrdLessThanEqual;
- } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSLessThanEqual;
- } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
- op = spv::Op::OpULessThanEqual;
- }
- break;
- }
+ return id;
}
- // Emit the instruction.
- current_function_.push_inst(op, {Type(ty), id, lhs, rhs});
-}
+ /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
+ /// @param constant the constant to get the ID for
+ /// @returns the result ID of the constant
+ uint32_t Constant(const core::constant::Value* constant) {
+ return constants_.GetOrCreate(constant, [&] {
+ auto* ty = constant->Type();
-void Printer::EmitBitcast(core::ir::Bitcast* bitcast) {
- auto* ty = bitcast->Result()->Type();
- if (ty == bitcast->Val()->Type()) {
- values_.Add(bitcast->Result(), Value(bitcast->Val()));
- return;
- }
- current_function_.push_inst(spv::Op::OpBitcast,
- {Type(ty), Value(bitcast), Value(bitcast->Val())});
-}
-
-void Printer::EmitSpirvBuiltinCall(spirv::ir::BuiltinCall* builtin) {
- auto id = Value(builtin);
-
- spv::Op op = spv::Op::Max;
- switch (builtin->Func()) {
- case spirv::BuiltinFn::kArrayLength:
- op = spv::Op::OpArrayLength;
- break;
- case spirv::BuiltinFn::kAtomicIadd:
- op = spv::Op::OpAtomicIAdd;
- break;
- case spirv::BuiltinFn::kAtomicIsub:
- op = spv::Op::OpAtomicISub;
- break;
- case spirv::BuiltinFn::kAtomicAnd:
- op = spv::Op::OpAtomicAnd;
- break;
- case spirv::BuiltinFn::kAtomicCompareExchange:
- op = spv::Op::OpAtomicCompareExchange;
- break;
- case spirv::BuiltinFn::kAtomicExchange:
- op = spv::Op::OpAtomicExchange;
- break;
- case spirv::BuiltinFn::kAtomicLoad:
- op = spv::Op::OpAtomicLoad;
- break;
- case spirv::BuiltinFn::kAtomicOr:
- op = spv::Op::OpAtomicOr;
- break;
- case spirv::BuiltinFn::kAtomicSmax:
- op = spv::Op::OpAtomicSMax;
- break;
- case spirv::BuiltinFn::kAtomicSmin:
- op = spv::Op::OpAtomicSMin;
- break;
- case spirv::BuiltinFn::kAtomicStore:
- op = spv::Op::OpAtomicStore;
- break;
- case spirv::BuiltinFn::kAtomicUmax:
- op = spv::Op::OpAtomicUMax;
- break;
- case spirv::BuiltinFn::kAtomicUmin:
- op = spv::Op::OpAtomicUMin;
- break;
- case spirv::BuiltinFn::kAtomicXor:
- op = spv::Op::OpAtomicXor;
- break;
- case spirv::BuiltinFn::kDot:
- op = spv::Op::OpDot;
- break;
- case spirv::BuiltinFn::kImageDrefGather:
- op = spv::Op::OpImageDrefGather;
- break;
- case spirv::BuiltinFn::kImageFetch:
- op = spv::Op::OpImageFetch;
- break;
- case spirv::BuiltinFn::kImageGather:
- op = spv::Op::OpImageGather;
- break;
- case spirv::BuiltinFn::kImageQuerySize:
- module_.PushCapability(SpvCapabilityImageQuery);
- op = spv::Op::OpImageQuerySize;
- break;
- case spirv::BuiltinFn::kImageQuerySizeLod:
- module_.PushCapability(SpvCapabilityImageQuery);
- op = spv::Op::OpImageQuerySizeLod;
- break;
- case spirv::BuiltinFn::kImageRead:
- op = spv::Op::OpImageRead;
- break;
- case spirv::BuiltinFn::kImageSampleImplicitLod:
- op = spv::Op::OpImageSampleImplicitLod;
- break;
- case spirv::BuiltinFn::kImageSampleExplicitLod:
- op = spv::Op::OpImageSampleExplicitLod;
- break;
- case spirv::BuiltinFn::kImageSampleDrefImplicitLod:
- op = spv::Op::OpImageSampleDrefImplicitLod;
- break;
- case spirv::BuiltinFn::kImageSampleDrefExplicitLod:
- op = spv::Op::OpImageSampleDrefExplicitLod;
- break;
- case spirv::BuiltinFn::kImageWrite:
- op = spv::Op::OpImageWrite;
- break;
- case spirv::BuiltinFn::kMatrixTimesMatrix:
- op = spv::Op::OpMatrixTimesMatrix;
- break;
- case spirv::BuiltinFn::kMatrixTimesScalar:
- op = spv::Op::OpMatrixTimesScalar;
- break;
- case spirv::BuiltinFn::kMatrixTimesVector:
- op = spv::Op::OpMatrixTimesVector;
- break;
- case spirv::BuiltinFn::kSampledImage:
- op = spv::Op::OpSampledImage;
- break;
- case spirv::BuiltinFn::kSdot:
- module_.PushExtension("SPV_KHR_integer_dot_product");
- module_.PushCapability(SpvCapabilityDotProductKHR);
- module_.PushCapability(SpvCapabilityDotProductInput4x8BitPackedKHR);
- op = spv::Op::OpSDot;
- break;
- case spirv::BuiltinFn::kSelect:
- op = spv::Op::OpSelect;
- break;
- case spirv::BuiltinFn::kUdot:
- module_.PushExtension("SPV_KHR_integer_dot_product");
- module_.PushCapability(SpvCapabilityDotProductKHR);
- module_.PushCapability(SpvCapabilityDotProductInput4x8BitPackedKHR);
- op = spv::Op::OpUDot;
- break;
- case spirv::BuiltinFn::kVectorTimesMatrix:
- op = spv::Op::OpVectorTimesMatrix;
- break;
- case spirv::BuiltinFn::kVectorTimesScalar:
- op = spv::Op::OpVectorTimesScalar;
- break;
- case spirv::BuiltinFn::kNone:
- TINT_ICE() << "undefined spirv ir function";
- return;
- }
-
- OperandList operands;
- if (!builtin->Result()->Type()->Is<core::type::Void>()) {
- operands = {Type(builtin->Result()->Type()), id};
- }
- for (auto* arg : builtin->Args()) {
- operands.push_back(Value(arg));
- }
- current_function_.push_inst(op, operands);
-}
-
-void Printer::EmitCoreBuiltinCall(core::ir::CoreBuiltinCall* builtin) {
- auto* result_ty = builtin->Result()->Type();
-
- if (builtin->Func() == core::BuiltinFn::kAbs &&
- result_ty->is_unsigned_integer_scalar_or_vector()) {
- // abs() is a no-op for unsigned integers.
- values_.Add(builtin->Result(), Value(builtin->Args()[0]));
- return;
- }
- if ((builtin->Func() == core::BuiltinFn::kAll || builtin->Func() == core::BuiltinFn::kAny) &&
- builtin->Args()[0]->Type()->Is<core::type::Bool>()) {
- // all() and any() are passthroughs for scalar arguments.
- values_.Add(builtin->Result(), Value(builtin->Args()[0]));
- return;
- }
-
- auto id = Value(builtin);
-
- spv::Op op = spv::Op::Max;
- OperandList operands = {Type(result_ty), id};
-
- // Helper to set up the opcode and operand list for a GLSL extended instruction.
- auto glsl_ext_inst = [&](enum GLSLstd450 inst) {
- constexpr const char* kGLSLstd450 = "GLSL.std.450";
- op = spv::Op::OpExtInst;
- operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&] {
- // Import the instruction set the first time it is requested.
- auto import = module_.NextId();
- module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)});
- return import;
- }));
- operands.push_back(U32Operand(inst));
- };
-
- // Determine the opcode.
- switch (builtin->Func()) {
- case core::BuiltinFn::kAbs:
- if (result_ty->is_float_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450FAbs);
- } else if (result_ty->is_signed_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450SAbs);
+ // Use OpConstantNull for zero-valued composite constants.
+ if (!ty->Is<core::type::Scalar>() && constant->AllZero()) {
+ return ConstantNull(ty);
}
- break;
- case core::BuiltinFn::kAll:
- op = spv::Op::OpAll;
- break;
- case core::BuiltinFn::kAny:
- op = spv::Op::OpAny;
- break;
- case core::BuiltinFn::kAcos:
- glsl_ext_inst(GLSLstd450Acos);
- break;
- case core::BuiltinFn::kAcosh:
- glsl_ext_inst(GLSLstd450Acosh);
- break;
- case core::BuiltinFn::kAsin:
- glsl_ext_inst(GLSLstd450Asin);
- break;
- case core::BuiltinFn::kAsinh:
- glsl_ext_inst(GLSLstd450Asinh);
- break;
- case core::BuiltinFn::kAtan:
- glsl_ext_inst(GLSLstd450Atan);
- break;
- case core::BuiltinFn::kAtan2:
- glsl_ext_inst(GLSLstd450Atan2);
- break;
- case core::BuiltinFn::kAtanh:
- glsl_ext_inst(GLSLstd450Atanh);
- break;
- case core::BuiltinFn::kClamp:
- if (result_ty->is_float_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450NClamp);
- } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450UClamp);
- } else if (result_ty->is_signed_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450SClamp);
- }
- break;
- case core::BuiltinFn::kCeil:
- glsl_ext_inst(GLSLstd450Ceil);
- break;
- case core::BuiltinFn::kCos:
- glsl_ext_inst(GLSLstd450Cos);
- break;
- case core::BuiltinFn::kCosh:
- glsl_ext_inst(GLSLstd450Cosh);
- break;
- case core::BuiltinFn::kCountOneBits:
- op = spv::Op::OpBitCount;
- break;
- case core::BuiltinFn::kCross:
- glsl_ext_inst(GLSLstd450Cross);
- break;
- case core::BuiltinFn::kDegrees:
- glsl_ext_inst(GLSLstd450Degrees);
- break;
- case core::BuiltinFn::kDeterminant:
- glsl_ext_inst(GLSLstd450Determinant);
- break;
- case core::BuiltinFn::kDistance:
- glsl_ext_inst(GLSLstd450Distance);
- break;
- case core::BuiltinFn::kDpdx:
- op = spv::Op::OpDPdx;
- break;
- case core::BuiltinFn::kDpdxCoarse:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpDPdxCoarse;
- break;
- case core::BuiltinFn::kDpdxFine:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpDPdxFine;
- break;
- case core::BuiltinFn::kDpdy:
- op = spv::Op::OpDPdy;
- break;
- case core::BuiltinFn::kDpdyCoarse:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpDPdyCoarse;
- break;
- case core::BuiltinFn::kDpdyFine:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpDPdyFine;
- break;
- case core::BuiltinFn::kExp:
- glsl_ext_inst(GLSLstd450Exp);
- break;
- case core::BuiltinFn::kExp2:
- glsl_ext_inst(GLSLstd450Exp2);
- break;
- case core::BuiltinFn::kExtractBits:
- op = result_ty->is_signed_integer_scalar_or_vector() ? spv::Op::OpBitFieldSExtract
- : spv::Op::OpBitFieldUExtract;
- break;
- case core::BuiltinFn::kFaceForward:
- glsl_ext_inst(GLSLstd450FaceForward);
- break;
- case core::BuiltinFn::kFloor:
- glsl_ext_inst(GLSLstd450Floor);
- break;
- case core::BuiltinFn::kFma:
- glsl_ext_inst(GLSLstd450Fma);
- break;
- case core::BuiltinFn::kFract:
- glsl_ext_inst(GLSLstd450Fract);
- break;
- case core::BuiltinFn::kFrexp:
- glsl_ext_inst(GLSLstd450FrexpStruct);
- break;
- case core::BuiltinFn::kFwidth:
- op = spv::Op::OpFwidth;
- break;
- case core::BuiltinFn::kFwidthCoarse:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpFwidthCoarse;
- break;
- case core::BuiltinFn::kFwidthFine:
- module_.PushCapability(SpvCapabilityDerivativeControl);
- op = spv::Op::OpFwidthFine;
- break;
- case core::BuiltinFn::kInsertBits:
- op = spv::Op::OpBitFieldInsert;
- break;
- case core::BuiltinFn::kInverseSqrt:
- glsl_ext_inst(GLSLstd450InverseSqrt);
- break;
- case core::BuiltinFn::kLdexp:
- glsl_ext_inst(GLSLstd450Ldexp);
- break;
- case core::BuiltinFn::kLength:
- glsl_ext_inst(GLSLstd450Length);
- break;
- case core::BuiltinFn::kLog:
- glsl_ext_inst(GLSLstd450Log);
- break;
- case core::BuiltinFn::kLog2:
- glsl_ext_inst(GLSLstd450Log2);
- break;
- case core::BuiltinFn::kMax:
- if (result_ty->is_float_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450FMax);
- } else if (result_ty->is_signed_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450SMax);
- } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450UMax);
- }
- break;
- case core::BuiltinFn::kMin:
- if (result_ty->is_float_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450FMin);
- } else if (result_ty->is_signed_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450SMin);
- } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450UMin);
- }
- break;
- case core::BuiltinFn::kMix:
- glsl_ext_inst(GLSLstd450FMix);
- break;
- case core::BuiltinFn::kModf:
- glsl_ext_inst(GLSLstd450ModfStruct);
- break;
- case core::BuiltinFn::kNormalize:
- glsl_ext_inst(GLSLstd450Normalize);
- break;
- case core::BuiltinFn::kPack2X16Float:
- glsl_ext_inst(GLSLstd450PackHalf2x16);
- break;
- case core::BuiltinFn::kPack2X16Snorm:
- glsl_ext_inst(GLSLstd450PackSnorm2x16);
- break;
- case core::BuiltinFn::kPack2X16Unorm:
- glsl_ext_inst(GLSLstd450PackUnorm2x16);
- break;
- case core::BuiltinFn::kPack4X8Snorm:
- glsl_ext_inst(GLSLstd450PackSnorm4x8);
- break;
- case core::BuiltinFn::kPack4X8Unorm:
- glsl_ext_inst(GLSLstd450PackUnorm4x8);
- break;
- case core::BuiltinFn::kPow:
- glsl_ext_inst(GLSLstd450Pow);
- break;
- case core::BuiltinFn::kQuantizeToF16:
- op = spv::Op::OpQuantizeToF16;
- break;
- case core::BuiltinFn::kRadians:
- glsl_ext_inst(GLSLstd450Radians);
- break;
- case core::BuiltinFn::kReflect:
- glsl_ext_inst(GLSLstd450Reflect);
- break;
- case core::BuiltinFn::kRefract:
- glsl_ext_inst(GLSLstd450Refract);
- break;
- case core::BuiltinFn::kReverseBits:
- op = spv::Op::OpBitReverse;
- break;
- case core::BuiltinFn::kRound:
- glsl_ext_inst(GLSLstd450RoundEven);
- break;
- case core::BuiltinFn::kSign:
- if (result_ty->is_float_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450FSign);
- } else if (result_ty->is_signed_integer_scalar_or_vector()) {
- glsl_ext_inst(GLSLstd450SSign);
- }
- break;
- case core::BuiltinFn::kSin:
- glsl_ext_inst(GLSLstd450Sin);
- break;
- case core::BuiltinFn::kSinh:
- glsl_ext_inst(GLSLstd450Sinh);
- break;
- case core::BuiltinFn::kSmoothstep:
- glsl_ext_inst(GLSLstd450SmoothStep);
- break;
- case core::BuiltinFn::kSqrt:
- glsl_ext_inst(GLSLstd450Sqrt);
- break;
- case core::BuiltinFn::kStep:
- glsl_ext_inst(GLSLstd450Step);
- break;
- case core::BuiltinFn::kStorageBarrier:
- op = spv::Op::OpControlBarrier;
- operands.clear();
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(
- Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::UniformMemory |
- spv::MemorySemanticsMask::AcquireRelease))));
- break;
- case core::BuiltinFn::kSubgroupBallot:
- module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
- op = spv::Op::OpGroupNonUniformBallot;
- operands.push_back(Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup))));
- operands.push_back(Constant(ir_.constant_values.Get(true)));
- break;
- case core::BuiltinFn::kSubgroupBroadcast:
- module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
- op = spv::Op::OpGroupNonUniformBroadcast;
- operands.push_back(Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup))));
- break;
- case core::BuiltinFn::kTan:
- glsl_ext_inst(GLSLstd450Tan);
- break;
- case core::BuiltinFn::kTanh:
- glsl_ext_inst(GLSLstd450Tanh);
- break;
- case core::BuiltinFn::kTextureBarrier:
- op = spv::Op::OpControlBarrier;
- operands.clear();
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(
- Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::ImageMemory |
- spv::MemorySemanticsMask::AcquireRelease))));
- break;
- case core::BuiltinFn::kTextureNumLevels:
- module_.PushCapability(SpvCapabilityImageQuery);
- op = spv::Op::OpImageQueryLevels;
- break;
- case core::BuiltinFn::kTextureNumSamples:
- module_.PushCapability(SpvCapabilityImageQuery);
- op = spv::Op::OpImageQuerySamples;
- break;
- case core::BuiltinFn::kTranspose:
- op = spv::Op::OpTranspose;
- break;
- case core::BuiltinFn::kTrunc:
- glsl_ext_inst(GLSLstd450Trunc);
- break;
- case core::BuiltinFn::kUnpack2X16Float:
- glsl_ext_inst(GLSLstd450UnpackHalf2x16);
- break;
- case core::BuiltinFn::kUnpack2X16Snorm:
- glsl_ext_inst(GLSLstd450UnpackSnorm2x16);
- break;
- case core::BuiltinFn::kUnpack2X16Unorm:
- glsl_ext_inst(GLSLstd450UnpackUnorm2x16);
- break;
- case core::BuiltinFn::kUnpack4X8Snorm:
- glsl_ext_inst(GLSLstd450UnpackSnorm4x8);
- break;
- case core::BuiltinFn::kUnpack4X8Unorm:
- glsl_ext_inst(GLSLstd450UnpackUnorm4x8);
- break;
- case core::BuiltinFn::kWorkgroupBarrier:
- op = spv::Op::OpControlBarrier;
- operands.clear();
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
- operands.push_back(
- Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::WorkgroupMemory |
- spv::MemorySemanticsMask::AcquireRelease))));
- break;
- default:
- TINT_ICE() << "unimplemented builtin function: " << builtin->Func();
- }
- TINT_ASSERT(op != spv::Op::Max);
- // Add the arguments to the builtin call.
- for (auto* arg : builtin->Args()) {
- operands.push_back(Value(arg));
+ auto id = module_.NextId();
+ Switch(
+ ty, //
+ [&](const core::type::Bool*) {
+ module_.PushType(constant->ValueAs<bool>() ? spv::Op::OpConstantTrue
+ : spv::Op::OpConstantFalse,
+ {Type(ty), id});
+ },
+ [&](const core::type::I32*) {
+ module_.PushType(spv::Op::OpConstant, {Type(ty), id, constant->ValueAs<u32>()});
+ },
+ [&](const core::type::U32*) {
+ module_.PushType(spv::Op::OpConstant,
+ {Type(ty), id, U32Operand(constant->ValueAs<i32>())});
+ },
+ [&](const core::type::F32*) {
+ module_.PushType(spv::Op::OpConstant, {Type(ty), id, constant->ValueAs<f32>()});
+ },
+ [&](const core::type::F16*) {
+ module_.PushType(
+ spv::Op::OpConstant,
+ {Type(ty), id, U32Operand(constant->ValueAs<f16>().BitsRepresentation())});
+ },
+ [&](const core::type::Vector* vec) {
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < vec->Width(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ },
+ [&](const core::type::Matrix* mat) {
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < mat->columns(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ },
+ [&](const core::type::Array* arr) {
+ TINT_ASSERT(arr->ConstantCount());
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < arr->ConstantCount(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ },
+ [&](const core::type::Struct* str) {
+ OperandList operands = {Type(ty), id};
+ for (uint32_t i = 0; i < str->Members().Length(); i++) {
+ operands.push_back(Constant(constant->Index(i)));
+ }
+ module_.PushType(spv::Op::OpConstantComposite, operands);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ return id;
+ });
}
- // Emit the instruction.
- current_function_.push_inst(op, operands);
-}
-
-void Printer::EmitConstruct(core::ir::Construct* construct) {
- // If there is just a single argument with the same type as the result, this is an identity
- // constructor and we can just pass through the ID of the argument.
- if (construct->Args().Length() == 1 &&
- construct->Result()->Type() == construct->Args()[0]->Type()) {
- values_.Add(construct->Result(), Value(construct->Args()[0]));
- return;
+ /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
+ /// @param type the type to get the ID for
+ /// @returns the result ID of the OpConstantNull instruction
+ uint32_t ConstantNull(const core::type::Type* type) {
+ return constant_nulls_.GetOrCreate(type, [&] {
+ auto id = module_.NextId();
+ module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
+ return id;
+ });
}
- OperandList operands = {Type(construct->Result()->Type()), Value(construct)};
- for (auto* arg : construct->Args()) {
- operands.push_back(Value(arg));
- }
- current_function_.push_inst(spv::Op::OpCompositeConstruct, std::move(operands));
-}
-
-void Printer::EmitConvert(core::ir::Convert* convert) {
- auto* res_ty = convert->Result()->Type();
- auto* arg_ty = convert->Args()[0]->Type();
-
- OperandList operands = {Type(convert->Result()->Type()), Value(convert)};
- for (auto* arg : convert->Args()) {
- operands.push_back(Value(arg));
+ /// Get the result ID of the OpUndef instruction with type `ty`, emitting it if necessary.
+ /// @param type the type of the undef value
+ /// @returns the result ID of the instruction
+ uint32_t Undef(const core::type::Type* type) {
+ return undef_values_.GetOrCreate(type, [&] {
+ auto id = module_.NextId();
+ module_.PushType(spv::Op::OpUndef, {Type(type), id});
+ return id;
+ });
}
- spv::Op op = spv::Op::Max;
- if (res_ty->is_signed_integer_scalar_or_vector() && arg_ty->is_float_scalar_or_vector()) {
- // float to signed int.
- op = spv::Op::OpConvertFToS;
- } else if (res_ty->is_unsigned_integer_scalar_or_vector() &&
- arg_ty->is_float_scalar_or_vector()) {
- // float to unsigned int.
- op = spv::Op::OpConvertFToU;
- } else if (res_ty->is_float_scalar_or_vector() &&
- arg_ty->is_signed_integer_scalar_or_vector()) {
- // signed int to float.
- op = spv::Op::OpConvertSToF;
- } else if (res_ty->is_float_scalar_or_vector() &&
- arg_ty->is_unsigned_integer_scalar_or_vector()) {
- // unsigned int to float.
- op = spv::Op::OpConvertUToF;
- } else if (res_ty->is_float_scalar_or_vector() && arg_ty->is_float_scalar_or_vector() &&
- res_ty->Size() != arg_ty->Size()) {
- // float to float (different bitwidth).
- op = spv::Op::OpFConvert;
- } else if (res_ty->is_integer_scalar_or_vector() && arg_ty->is_integer_scalar_or_vector() &&
- res_ty->Size() == arg_ty->Size()) {
- // int to int (same bitwidth, different signedness).
- op = spv::Op::OpBitcast;
- } else if (res_ty->is_bool_scalar_or_vector()) {
- if (arg_ty->is_integer_scalar_or_vector()) {
- // int to bool.
- op = spv::Op::OpINotEqual;
- } else {
- // float to bool.
- op = spv::Op::OpFUnordNotEqual;
- }
- operands.push_back(ConstantNull(arg_ty));
- } else if (arg_ty->is_bool_scalar_or_vector()) {
- // Select between constant one and zero, splatting them to vectors if necessary.
- core::ir::Constant* one = nullptr;
- core::ir::Constant* zero = nullptr;
- Switch(
- res_ty->DeepestElement(), //
- [&](const core::type::F32*) {
- one = b_.Constant(1_f);
- zero = b_.Constant(0_f);
- },
- [&](const core::type::F16*) {
- one = b_.Constant(1_h);
- zero = b_.Constant(0_h);
- },
- [&](const core::type::I32*) {
- one = b_.Constant(1_i);
- zero = b_.Constant(0_i);
- },
- [&](const core::type::U32*) {
- one = b_.Constant(1_u);
- zero = b_.Constant(0_u);
+ /// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
+ /// @param ty the type to get the ID for
+ /// @returns the result ID of the type
+ uint32_t Type(const core::type::Type* ty) {
+ ty = DedupType(ty, ir_.Types());
+ return types_.GetOrCreate(ty, [&] {
+ auto id = module_.NextId();
+ Switch(
+ ty, //
+ [&](const core::type::Void*) { module_.PushType(spv::Op::OpTypeVoid, {id}); },
+ [&](const core::type::Bool*) { module_.PushType(spv::Op::OpTypeBool, {id}); },
+ [&](const core::type::I32*) {
+ module_.PushType(spv::Op::OpTypeInt, {id, 32u, 1u});
+ },
+ [&](const core::type::U32*) {
+ module_.PushType(spv::Op::OpTypeInt, {id, 32u, 0u});
+ },
+ [&](const core::type::F32*) {
+ module_.PushType(spv::Op::OpTypeFloat, {id, 32u});
+ },
+ [&](const core::type::F16*) {
+ module_.PushCapability(SpvCapabilityFloat16);
+ module_.PushCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
+ module_.PushCapability(SpvCapabilityStorageBuffer16BitAccess);
+ module_.PushCapability(SpvCapabilityStorageInputOutput16);
+ module_.PushType(spv::Op::OpTypeFloat, {id, 16u});
+ },
+ [&](const core::type::Vector* vec) {
+ module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()});
+ },
+ [&](const core::type::Matrix* mat) {
+ module_.PushType(spv::Op::OpTypeMatrix,
+ {id, Type(mat->ColumnType()), mat->columns()});
+ },
+ [&](const core::type::Array* arr) {
+ if (arr->ConstantCount()) {
+ auto* count = b_.ConstantValue(u32(arr->ConstantCount().value()));
+ module_.PushType(spv::Op::OpTypeArray,
+ {id, Type(arr->ElemType()), Constant(count)});
+ } else {
+ TINT_ASSERT(arr->Count()->Is<core::type::RuntimeArrayCount>());
+ module_.PushType(spv::Op::OpTypeRuntimeArray, {id, Type(arr->ElemType())});
+ }
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationArrayStride), arr->Stride()});
+ },
+ [&](const core::type::Pointer* ptr) {
+ module_.PushType(spv::Op::OpTypePointer,
+ {id, U32Operand(StorageClass(ptr->AddressSpace())),
+ Type(ptr->StoreType())});
+ },
+ [&](const core::type::Struct* str) { EmitStructType(id, str); },
+ [&](const core::type::Texture* tex) { EmitTextureType(id, tex); },
+ [&](const core::type::Sampler*) { module_.PushType(spv::Op::OpTypeSampler, {id}); },
+ [&](const type::SampledImage* s) {
+ module_.PushType(spv::Op::OpTypeSampledImage, {id, Type(s->Image())});
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ return id;
+ });
+ }
+
+ /// Get the result ID of the instruction result `value`, emitting its instruction if necessary.
+ /// @param inst the instruction to get the ID for
+ /// @returns the result ID of the instruction
+ uint32_t Value(core::ir::Instruction* inst) { return Value(inst->Result()); }
+
+ /// Get the result ID of the value `value`, emitting its instruction if necessary.
+ /// @param value the value to get the ID for
+ /// @returns the result ID of the value
+ uint32_t Value(core::ir::Value* value) {
+ return Switch(
+ value, //
+ [&](core::ir::Constant* constant) { return Constant(constant); },
+ [&](core::ir::Value*) {
+ return values_.GetOrCreate(value, [&] { return module_.NextId(); });
});
- TINT_ASSERT_OR_RETURN(one && zero);
-
- if (auto* vec = res_ty->As<core::type::Vector>()) {
- // Splat the scalars into vectors.
- one = b_.Splat(vec, one, vec->Width());
- zero = b_.Splat(vec, zero, vec->Width());
- }
-
- op = spv::Op::OpSelect;
- operands.push_back(Constant(b_.ConstantValue(one)));
- operands.push_back(Constant(b_.ConstantValue(zero)));
- } else {
- TINT_ICE() << "unhandled convert instruction";
}
- current_function_.push_inst(op, std::move(operands));
-}
-
-void Printer::EmitLoad(core::ir::Load* load) {
- current_function_.push_inst(spv::Op::OpLoad,
- {Type(load->Result()->Type()), Value(load), Value(load->From())});
-}
-
-void Printer::EmitLoadVectorElement(core::ir::LoadVectorElement* load) {
- auto* vec_ptr_ty = load->From()->Type()->As<core::type::Pointer>();
- auto* el_ty = load->Result()->Type();
- auto* el_ptr_ty = ir_.Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
- auto el_ptr_id = module_.NextId();
- current_function_.push_inst(
- spv::Op::OpAccessChain,
- {Type(el_ptr_ty), el_ptr_id, Value(load->From()), Value(load->Index())});
- current_function_.push_inst(spv::Op::OpLoad,
- {Type(load->Result()->Type()), Value(load), el_ptr_id});
-}
-
-void Printer::EmitLoop(core::ir::Loop* loop) {
- auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
- auto body_label = Label(loop->Body());
- auto continuing_label = Label(loop->Continuing());
-
- auto header_label = module_.NextId();
- TINT_SCOPED_ASSIGNMENT(loop_header_label_, header_label);
-
- auto merge_label = GetMergeLabel(loop);
- TINT_SCOPED_ASSIGNMENT(loop_merge_label_, merge_label);
-
- if (init_label != 0) {
- // Emit the loop initializer.
- current_function_.push_inst(spv::Op::OpBranch, {init_label});
- EmitBlock(loop->Initializer());
- } else {
- // No initializer. Branch to body.
- current_function_.push_inst(spv::Op::OpBranch, {header_label});
+ /// Get the ID of the label for `block`.
+ /// @param block the block to get the label ID for
+ /// @returns the ID of the block's label
+ uint32_t Label(core::ir::Block* block) {
+ return block_labels_.GetOrCreate(block, [&] { return module_.NextId(); });
}
- // Emit the loop body header, which contains the OpLoopMerge and OpPhis.
- // This then unconditionally branches to body_label
- current_function_.push_inst(spv::Op::OpLabel, {header_label});
- EmitIncomingPhis(loop->Body());
- current_function_.push_inst(
- spv::Op::OpLoopMerge, {merge_label, continuing_label, U32Operand(SpvLoopControlMaskNone)});
- current_function_.push_inst(spv::Op::OpBranch, {body_label});
+ /// Emit a struct type.
+ /// @param id the result ID to use
+ /// @param str the struct type to emit
+ void EmitStructType(uint32_t id, const core::type::Struct* str) {
+ // Helper to return `type` or a potentially nested array element type within `type` as a
+ // matrix type, or nullptr if no such matrix type is present.
+ auto get_nested_matrix_type = [&](const core::type::Type* type) {
+ while (auto* arr = type->As<core::type::Array>()) {
+ type = arr->ElemType();
+ }
+ return type->As<core::type::Matrix>();
+ };
- // Emit the loop body
- current_function_.push_inst(spv::Op::OpLabel, {body_label});
- EmitBlockInstructions(loop->Body());
+ OperandList operands = {id};
+ for (auto* member : str->Members()) {
+ operands.push_back(Type(member->Type()));
- // Emit the loop continuing block.
- if (loop->Continuing()->HasTerminator()) {
- EmitBlock(loop->Continuing());
- } else {
- // We still need to emit a continuing block with a back-edge, even if it is unreachable.
- current_function_.push_inst(spv::Op::OpLabel, {continuing_label});
- current_function_.push_inst(spv::Op::OpBranch, {header_label});
- }
+ // Generate struct member offset decoration.
+ module_.PushAnnot(
+ spv::Op::OpMemberDecorate,
+ {operands[0], member->Index(), U32Operand(SpvDecorationOffset), member->Offset()});
- // Emit the loop merge block.
- current_function_.push_inst(spv::Op::OpLabel, {merge_label});
+ // Emit matrix layout decorations if necessary.
+ if (auto* matrix_type = get_nested_matrix_type(member->Type())) {
+ const uint32_t effective_row_count = (matrix_type->rows() == 2) ? 2 : 4;
+ module_.PushAnnot(spv::Op::OpMemberDecorate,
+ {id, member->Index(), U32Operand(SpvDecorationColMajor)});
+ module_.PushAnnot(spv::Op::OpMemberDecorate,
+ {id, member->Index(), U32Operand(SpvDecorationMatrixStride),
+ Operand(effective_row_count * matrix_type->type()->Size())});
+ }
- // Emit the OpPhis for the ExitLoops
- EmitExitPhis(loop);
-}
-
-void Printer::EmitSwitch(core::ir::Switch* swtch) {
- // Find the default selector. There must be exactly one.
- uint32_t default_label = 0u;
- for (auto& c : swtch->Cases()) {
- for (auto& sel : c.selectors) {
- if (sel.IsDefault()) {
- default_label = Label(c.Block());
+ if (member->Name().IsValid()) {
+ module_.PushDebug(spv::Op::OpMemberName,
+ {operands[0], member->Index(), Operand(member->Name().Name())});
}
}
- }
- TINT_ASSERT(default_label != 0u);
+ module_.PushType(spv::Op::OpTypeStruct, std::move(operands));
- // Build the operands to the OpSwitch instruction.
- OperandList switch_operands = {Value(swtch->Condition()), default_label};
- for (auto& c : swtch->Cases()) {
- auto label = Label(c.Block());
- for (auto& sel : c.selectors) {
- if (sel.IsDefault()) {
+ // Add a Block decoration if necessary.
+ if (str->StructFlags().Contains(core::type::StructFlag::kBlock)) {
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBlock)});
+ }
+
+ if (str->Name().IsValid()) {
+ module_.PushDebug(spv::Op::OpName, {operands[0], Operand(str->Name().Name())});
+ }
+ }
+
+ /// Emit a texture type.
+ /// @param id the result ID to use
+ /// @param texture the texture type to emit
+ void EmitTextureType(uint32_t id, const core::type::Texture* texture) {
+ uint32_t sampled_type = Switch(
+ texture, //
+ [&](const core::type::SampledTexture* t) { return Type(t->type()); },
+ [&](const core::type::MultisampledTexture* t) { return Type(t->type()); },
+ [&](const core::type::StorageTexture* t) { return Type(t->type()); }, //
+ TINT_ICE_ON_NO_MATCH);
+
+ uint32_t dim = SpvDimMax;
+ uint32_t array = 0u;
+ switch (texture->dim()) {
+ case core::type::TextureDimension::kNone: {
+ break;
+ }
+ case core::type::TextureDimension::k1d: {
+ dim = SpvDim1D;
+ if (texture->Is<core::type::SampledTexture>()) {
+ module_.PushCapability(SpvCapabilitySampled1D);
+ } else if (texture->Is<core::type::StorageTexture>()) {
+ module_.PushCapability(SpvCapabilityImage1D);
+ }
+ break;
+ }
+ case core::type::TextureDimension::k2d: {
+ dim = SpvDim2D;
+ break;
+ }
+ case core::type::TextureDimension::k2dArray: {
+ dim = SpvDim2D;
+ array = 1u;
+ break;
+ }
+ case core::type::TextureDimension::k3d: {
+ dim = SpvDim3D;
+ break;
+ }
+ case core::type::TextureDimension::kCube: {
+ dim = SpvDimCube;
+ break;
+ }
+ case core::type::TextureDimension::kCubeArray: {
+ dim = SpvDimCube;
+ array = 1u;
+ if (texture->Is<core::type::SampledTexture>()) {
+ module_.PushCapability(SpvCapabilitySampledCubeArray);
+ }
+ break;
+ }
+ }
+
+ // The Vulkan spec says: The "Depth" operand of OpTypeImage is ignored.
+ // In SPIRV, 0 means not depth, 1 means depth, and 2 means unknown.
+ // Using anything other than 0 is problematic on various Vulkan drivers.
+ uint32_t depth = 0u;
+
+ uint32_t ms = 0u;
+ if (texture->Is<core::type::MultisampledTexture>()) {
+ ms = 1u;
+ }
+
+ uint32_t sampled = 2u;
+ if (texture->IsAnyOf<core::type::MultisampledTexture, core::type::SampledTexture>()) {
+ sampled = 1u;
+ }
+
+ uint32_t format = SpvImageFormat_::SpvImageFormatUnknown;
+ if (auto* st = texture->As<core::type::StorageTexture>()) {
+ format = TexelFormat(st->texel_format());
+ }
+
+ module_.PushType(spv::Op::OpTypeImage,
+ {id, sampled_type, dim, depth, array, ms, sampled, format});
+ }
+
+ /// Emit a function.
+ /// @param func the function to emit
+ void EmitFunction(core::ir::Function* func) {
+ auto id = Value(func);
+
+ // Emit the function name.
+ module_.PushDebug(spv::Op::OpName, {id, Operand(ir_.NameOf(func).Name())});
+
+ // Emit OpEntryPoint and OpExecutionMode declarations if needed.
+ if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
+ EmitEntryPoint(func, id);
+ }
+
+ // Get the ID for the return type.
+ auto return_type_id = Type(func->ReturnType());
+
+ FunctionType function_type{return_type_id, {}};
+ InstructionList params;
+
+ // Generate function parameter declarations and add their type IDs to the function
+ // signature.
+ for (auto* param : func->Params()) {
+ auto param_type_id = Type(param->Type());
+ auto param_id = Value(param);
+ params.push_back(Instruction(spv::Op::OpFunctionParameter, {param_type_id, param_id}));
+ function_type.param_type_ids.Push(param_type_id);
+ if (auto name = ir_.NameOf(param)) {
+ module_.PushDebug(spv::Op::OpName, {param_id, Operand(name.Name())});
+ }
+ }
+
+ // Get the ID for the function type (creating it if needed).
+ auto function_type_id = function_types_.GetOrCreate(function_type, [&] {
+ auto func_ty_id = module_.NextId();
+ OperandList operands = {func_ty_id, return_type_id};
+ operands.insert(operands.end(), function_type.param_type_ids.begin(),
+ function_type.param_type_ids.end());
+ module_.PushType(spv::Op::OpTypeFunction, operands);
+ return func_ty_id;
+ });
+
+ // Declare the function.
+ auto decl = Instruction{
+ spv::Op::OpFunction,
+ {return_type_id, id, U32Operand(SpvFunctionControlMaskNone), function_type_id}};
+
+ // Create a function that we will add instructions to.
+ auto entry_block = module_.NextId();
+ current_function_ = Function(decl, entry_block, std::move(params));
+ TINT_DEFER(current_function_ = Function());
+
+ // Emit the body of the function.
+ EmitBlock(func->Block());
+
+ // Add the function to the module.
+ module_.PushFunction(current_function_);
+ }
+
+ /// Emit entry point declarations for a function.
+ /// @param func the function to emit entry point declarations for
+ /// @param id the result ID of the function declaration
+ void EmitEntryPoint(core::ir::Function* func, uint32_t id) {
+ SpvExecutionModel stage = SpvExecutionModelMax;
+ switch (func->Stage()) {
+ case core::ir::Function::PipelineStage::kCompute: {
+ stage = SpvExecutionModelGLCompute;
+ module_.PushExecutionMode(
+ spv::Op::OpExecutionMode,
+ {id, U32Operand(SpvExecutionModeLocalSize), func->WorkgroupSize()->at(0),
+ func->WorkgroupSize()->at(1), func->WorkgroupSize()->at(2)});
+ break;
+ }
+ case core::ir::Function::PipelineStage::kFragment: {
+ stage = SpvExecutionModelFragment;
+ module_.PushExecutionMode(spv::Op::OpExecutionMode,
+ {id, U32Operand(SpvExecutionModeOriginUpperLeft)});
+ break;
+ }
+ case core::ir::Function::PipelineStage::kVertex: {
+ stage = SpvExecutionModelVertex;
+ break;
+ }
+ case core::ir::Function::PipelineStage::kUndefined:
+ TINT_ICE() << "undefined pipeline stage for entry point";
+ return;
+ }
+
+ OperandList operands = {U32Operand(stage), id, ir_.NameOf(func).Name()};
+
+ // Add the list of all referenced shader IO variables.
+ for (auto* global : *ir_.root_block) {
+ auto* var = global->As<core::ir::Var>();
+ if (!var) {
continue;
}
- switch_operands.push_back(sel.val->Value()->ValueAs<uint32_t>());
- switch_operands.push_back(label);
- }
- }
- uint32_t merge_label = GetMergeLabel(swtch);
- TINT_SCOPED_ASSIGNMENT(switch_merge_label_, merge_label);
-
- // Emit the OpSelectionMerge and OpSwitch instructions.
- current_function_.push_inst(spv::Op::OpSelectionMerge,
- {merge_label, U32Operand(SpvSelectionControlMaskNone)});
- current_function_.push_inst(spv::Op::OpSwitch, switch_operands);
-
- // Emit the cases.
- for (auto& c : swtch->Cases()) {
- EmitBlock(c.Block());
- }
-
- // Emit the switch merge block.
- current_function_.push_inst(spv::Op::OpLabel, {merge_label});
-
- // Emit the OpPhis for the ExitSwitches
- EmitExitPhis(swtch);
-}
-
-void Printer::EmitSwizzle(core::ir::Swizzle* swizzle) {
- auto id = Value(swizzle);
- auto obj = Value(swizzle->Object());
- OperandList operands = {Type(swizzle->Result()->Type()), id, obj, obj};
- for (auto idx : swizzle->Indices()) {
- operands.push_back(idx);
- }
- current_function_.push_inst(spv::Op::OpVectorShuffle, operands);
-}
-
-void Printer::EmitStore(core::ir::Store* store) {
- current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
-}
-
-void Printer::EmitStoreVectorElement(core::ir::StoreVectorElement* store) {
- auto* vec_ptr_ty = store->To()->Type()->As<core::type::Pointer>();
- auto* el_ty = store->Value()->Type();
- auto* el_ptr_ty = ir_.Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
- auto el_ptr_id = module_.NextId();
- current_function_.push_inst(
- spv::Op::OpAccessChain,
- {Type(el_ptr_ty), el_ptr_id, Value(store->To()), Value(store->Index())});
- current_function_.push_inst(spv::Op::OpStore, {el_ptr_id, Value(store->Value())});
-}
-
-void Printer::EmitUnary(core::ir::Unary* unary) {
- auto id = Value(unary);
- auto* ty = unary->Result()->Type();
- spv::Op op = spv::Op::Max;
- switch (unary->Op()) {
- case core::ir::UnaryOp::kComplement:
- op = spv::Op::OpNot;
- break;
- case core::ir::UnaryOp::kNegation:
- if (ty->is_float_scalar_or_vector()) {
- op = spv::Op::OpFNegate;
- } else if (ty->is_signed_integer_scalar_or_vector()) {
- op = spv::Op::OpSNegate;
+ auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
+ if (!(ptr->AddressSpace() == core::AddressSpace::kIn ||
+ ptr->AddressSpace() == core::AddressSpace::kOut)) {
+ continue;
}
- break;
- }
- current_function_.push_inst(op, {Type(ty), id, Value(unary->Val())});
-}
-void Printer::EmitUserCall(core::ir::UserCall* call) {
- auto id = Value(call);
- OperandList operands = {Type(call->Result()->Type()), id, Value(call->Target())};
- for (auto* arg : call->Args()) {
- operands.push_back(Value(arg));
- }
- current_function_.push_inst(spv::Op::OpFunctionCall, operands);
-}
-
-void Printer::EmitIOAttributes(uint32_t id,
- const core::ir::IOAttributes& attrs,
- core::AddressSpace addrspace) {
- if (attrs.location) {
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationLocation), *attrs.location});
- }
- if (attrs.index) {
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationIndex), *attrs.index});
- }
- if (attrs.interpolation) {
- switch (attrs.interpolation->type) {
- case core::InterpolationType::kLinear:
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationNoPerspective)});
- break;
- case core::InterpolationType::kFlat:
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationFlat)});
- break;
- case core::InterpolationType::kPerspective:
- case core::InterpolationType::kUndefined:
- break;
- }
- switch (attrs.interpolation->sampling) {
- case core::InterpolationSampling::kCentroid:
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationCentroid)});
- break;
- case core::InterpolationSampling::kSample:
- module_.PushCapability(SpvCapabilitySampleRateShading);
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationSample)});
- break;
- case core::InterpolationSampling::kCenter:
- case core::InterpolationSampling::kUndefined:
- break;
- }
- }
- if (attrs.builtin) {
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBuiltIn),
- Builtin(*attrs.builtin, addrspace)});
- }
- if (attrs.invariant) {
- module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationInvariant)});
- }
-}
-
-void Printer::EmitVar(core::ir::Var* var) {
- auto id = Value(var);
- auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
- auto* store_ty = ptr->StoreType();
- auto ty = Type(ptr);
-
- switch (ptr->AddressSpace()) {
- case core::AddressSpace::kFunction: {
- TINT_ASSERT(current_function_);
- if (var->Initializer()) {
- current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)});
- current_function_.push_inst(spv::Op::OpStore, {id, Value(var->Initializer())});
- } else {
- current_function_.push_var(
- {ty, id, U32Operand(SpvStorageClassFunction), ConstantNull(store_ty)});
- }
- break;
- }
- case core::AddressSpace::kIn: {
- TINT_ASSERT(!current_function_);
- module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassInput)});
- EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kIn);
- break;
- }
- case core::AddressSpace::kPrivate: {
- TINT_ASSERT(!current_function_);
- OperandList operands = {ty, id, U32Operand(SpvStorageClassPrivate)};
- if (var->Initializer()) {
- TINT_ASSERT(var->Initializer()->Is<core::ir::Constant>());
- operands.push_back(Value(var->Initializer()));
- } else {
- operands.push_back(ConstantNull(store_ty));
- }
- module_.PushType(spv::Op::OpVariable, operands);
- break;
- }
- case core::AddressSpace::kPushConstant: {
- TINT_ASSERT(!current_function_);
- module_.PushType(spv::Op::OpVariable,
- {ty, id, U32Operand(SpvStorageClassPushConstant)});
- break;
- }
- case core::AddressSpace::kOut: {
- TINT_ASSERT(!current_function_);
- module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassOutput)});
- EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kOut);
- break;
- }
- case core::AddressSpace::kHandle:
- case core::AddressSpace::kStorage:
- case core::AddressSpace::kUniform: {
- TINT_ASSERT(!current_function_);
- module_.PushType(spv::Op::OpVariable,
- {ty, id, U32Operand(StorageClass(ptr->AddressSpace()))});
- auto bp = var->BindingPoint().value();
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationDescriptorSet), bp.group});
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationBinding), bp.binding});
-
- // Add NonReadable and NonWritable decorations to storage textures and buffers.
- auto* st = store_ty->As<core::type::StorageTexture>();
- if (st || store_ty->Is<core::type::Struct>()) {
- auto access = st ? st->access() : ptr->Access();
- if (access == core::Access::kRead) {
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationNonWritable)});
- } else if (access == core::Access::kWrite) {
- module_.PushAnnot(spv::Op::OpDecorate,
- {id, U32Operand(SpvDecorationNonReadable)});
+ // Determine if this IO variable is used by the entry point.
+ bool used = false;
+ for (const auto& use : var->Result()->Usages()) {
+ auto* block = use.instruction->Block();
+ while (block->Parent()) {
+ block = block->Parent()->Block();
+ }
+ if (block == func->Block()) {
+ used = true;
+ break;
}
}
- break;
- }
- case core::AddressSpace::kWorkgroup: {
- TINT_ASSERT(!current_function_);
- OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
- if (zero_init_workgroup_memory_) {
- // If requested, use the VK_KHR_zero_initialize_workgroup_memory to zero-initialize
- // the workgroup variable using an null constant initializer.
- operands.push_back(ConstantNull(store_ty));
+ if (!used) {
+ continue;
}
- module_.PushType(spv::Op::OpVariable, operands);
- break;
+ operands.push_back(Value(var));
+
+ // Add the `DepthReplacing` execution mode if `frag_depth` is used.
+ if (var->Attributes().builtin == core::BuiltinValue::kFragDepth) {
+ module_.PushExecutionMode(spv::Op::OpExecutionMode,
+ {id, U32Operand(SpvExecutionModeDepthReplacing)});
+ }
}
- default: {
- TINT_ICE() << "unimplemented variable address space " << ptr->AddressSpace();
+
+ module_.PushEntryPoint(spv::Op::OpEntryPoint, operands);
+ }
+
+ /// Emit the root block.
+ /// @param root_block the root block to emit
+ void EmitRootBlock(core::ir::Block* root_block) {
+ for (auto* inst : *root_block) {
+ Switch(
+ inst, //
+ [&](core::ir::Var* v) { return EmitVar(v); }, //
+ TINT_ICE_ON_NO_MATCH);
}
}
- // Set the name if present.
- if (auto name = ir_.NameOf(var)) {
- module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
- }
-}
-
-void Printer::EmitLet(core::ir::Let* let) {
- auto id = Value(let->Value());
- values_.Add(let->Result(), id);
-}
-
-void Printer::EmitExitPhis(core::ir::ControlInstruction* inst) {
- struct Branch {
- uint32_t label = 0;
- core::ir::Value* value = nullptr;
- bool operator<(const Branch& other) const { return label < other.label; }
- };
-
- auto results = inst->Results();
- for (size_t index = 0; index < results.Length(); index++) {
- auto* result = results[index];
- auto* ty = result->Type();
-
- Vector<Branch, 8> branches;
- branches.Reserve(inst->Exits().Count());
- for (auto& exit : inst->Exits()) {
- branches.Push(Branch{GetTerminatorBlockLabel(exit), exit->Args()[index]});
+ /// Emit a block, including the initial OpLabel, OpPhis and instructions.
+ /// @param block the block to emit
+ void EmitBlock(core::ir::Block* block) {
+ // Emit the label.
+ // Skip if this is the function's entry block, as it will be emitted by the function object.
+ if (!current_function_.instructions().empty()) {
+ current_function_.push_inst(spv::Op::OpLabel, {Label(block)});
}
- branches.Sort(); // Sort the branches by label to ensure deterministic output
- OperandList ops{Type(ty), Value(result)};
- for (auto& branch : branches) {
- if (branch.value == nullptr) {
- ops.push_back(Undef(ty));
+ // If there are no instructions in the block, it's a dead end, so we shouldn't be able to
+ // get here to begin with.
+ if (block->IsEmpty()) {
+ current_function_.push_inst(spv::Op::OpUnreachable, {});
+ return;
+ }
+
+ if (auto* mib = block->As<core::ir::MultiInBlock>()) {
+ // Emit all OpPhi nodes for incoming branches to block.
+ EmitIncomingPhis(mib);
+ }
+
+ // Emit the block's statements.
+ EmitBlockInstructions(block);
+ }
+
+ /// Emit all OpPhi nodes for incoming branches to @p block.
+ /// @param block the block to emit the OpPhis for
+ void EmitIncomingPhis(core::ir::MultiInBlock* block) {
+ // Emit Phi nodes for all the incoming block parameters
+ for (size_t param_idx = 0; param_idx < block->Params().Length(); param_idx++) {
+ auto* param = block->Params()[param_idx];
+ OperandList ops{Type(param->Type()), Value(param)};
+
+ for (auto* incoming : block->InboundSiblingBranches()) {
+ auto* arg = incoming->Args()[param_idx];
+ ops.push_back(Value(arg));
+ ops.push_back(GetTerminatorBlockLabel(incoming));
+ }
+
+ current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
+ }
+ }
+
+ /// Emit all instructions of @p block.
+ /// @param block the block's instructions to emit
+ void EmitBlockInstructions(core::ir::Block* block) {
+ for (auto* inst : *block) {
+ Switch(
+ inst, //
+ [&](core::ir::Access* a) { EmitAccess(a); }, //
+ [&](core::ir::Binary* b) { EmitBinary(b); }, //
+ [&](core::ir::Bitcast* b) { EmitBitcast(b); }, //
+ [&](core::ir::CoreBuiltinCall* b) { EmitCoreBuiltinCall(b); }, //
+ [&](spirv::ir::BuiltinCall* b) { EmitSpirvBuiltinCall(b); }, //
+ [&](core::ir::Construct* c) { EmitConstruct(c); }, //
+ [&](core::ir::Convert* c) { EmitConvert(c); }, //
+ [&](core::ir::Load* l) { EmitLoad(l); }, //
+ [&](core::ir::LoadVectorElement* l) { EmitLoadVectorElement(l); }, //
+ [&](core::ir::Loop* l) { EmitLoop(l); }, //
+ [&](core::ir::Switch* sw) { EmitSwitch(sw); }, //
+ [&](core::ir::Swizzle* s) { EmitSwizzle(s); }, //
+ [&](core::ir::Store* s) { EmitStore(s); }, //
+ [&](core::ir::StoreVectorElement* s) { EmitStoreVectorElement(s); }, //
+ [&](core::ir::UserCall* c) { EmitUserCall(c); }, //
+ [&](core::ir::Unary* u) { EmitUnary(u); }, //
+ [&](core::ir::Var* v) { EmitVar(v); }, //
+ [&](core::ir::Let* l) { EmitLet(l); }, //
+ [&](core::ir::If* i) { EmitIf(i); }, //
+ [&](core::ir::Terminator* t) { EmitTerminator(t); }, //
+ TINT_ICE_ON_NO_MATCH);
+
+ // Set the name for the SPIR-V result ID if provided in the module.
+ if (inst->Result() && !inst->Is<core::ir::Var>()) {
+ if (auto name = ir_.NameOf(inst)) {
+ module_.PushDebug(spv::Op::OpName, {Value(inst), Operand(name.Name())});
+ }
+ }
+ }
+
+ if (block->IsEmpty()) {
+ // If the last emitted instruction is not a branch, then this should be unreachable.
+ current_function_.push_inst(spv::Op::OpUnreachable, {});
+ }
+ }
+
+ /// Emit a terminator instruction.
+ /// @param t the terminator instruction to emit
+ void EmitTerminator(core::ir::Terminator* t) {
+ tint::Switch( //
+ t, //
+ [&](core::ir::Return*) {
+ if (!t->Args().IsEmpty()) {
+ TINT_ASSERT(t->Args().Length() == 1u);
+ OperandList operands;
+ operands.push_back(Value(t->Args()[0]));
+ current_function_.push_inst(spv::Op::OpReturnValue, operands);
+ } else {
+ current_function_.push_inst(spv::Op::OpReturn, {});
+ }
+ return;
+ },
+ [&](core::ir::BreakIf* breakif) {
+ current_function_.push_inst(spv::Op::OpBranchConditional,
+ {
+ Value(breakif->Condition()),
+ loop_merge_label_,
+ loop_header_label_,
+ });
+ },
+ [&](core::ir::Continue* cont) {
+ current_function_.push_inst(spv::Op::OpBranch, {Label(cont->Loop()->Continuing())});
+ },
+ [&](core::ir::ExitIf*) {
+ current_function_.push_inst(spv::Op::OpBranch, {if_merge_label_});
+ },
+ [&](core::ir::ExitLoop*) {
+ current_function_.push_inst(spv::Op::OpBranch, {loop_merge_label_});
+ },
+ [&](core::ir::ExitSwitch*) {
+ current_function_.push_inst(spv::Op::OpBranch, {switch_merge_label_});
+ },
+ [&](core::ir::NextIteration*) {
+ current_function_.push_inst(spv::Op::OpBranch, {loop_header_label_});
+ },
+ [&](core::ir::TerminateInvocation*) {
+ current_function_.push_inst(spv::Op::OpKill, {});
+ },
+ [&](core::ir::Unreachable*) {
+ current_function_.push_inst(spv::Op::OpUnreachable, {});
+ }, //
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ /// Emit an `if` flow node.
+ /// @param i the if node to emit
+ void EmitIf(core::ir::If* i) {
+ auto* true_block = i->True();
+ auto* false_block = i->False();
+
+ // Generate labels for the blocks. We emit the true or false block if it:
+ // 1. contains instructions other then the branch, or
+ // 2. branches somewhere instead of exiting the loop (e.g. return or break), or
+ // 3. the if returns a value
+ // Otherwise we skip them and branch straight to the merge block.
+ uint32_t merge_label = GetMergeLabel(i);
+ TINT_SCOPED_ASSIGNMENT(if_merge_label_, merge_label);
+
+ uint32_t true_label = merge_label;
+ uint32_t false_label = merge_label;
+ if (true_block->Length() > 1 || i->HasResults() ||
+ (true_block->HasTerminator() && !true_block->Terminator()->Is<core::ir::ExitIf>())) {
+ true_label = Label(true_block);
+ }
+ if (false_block->Length() > 1 || i->HasResults() ||
+ (false_block->HasTerminator() && !false_block->Terminator()->Is<core::ir::ExitIf>())) {
+ false_label = Label(false_block);
+ }
+
+ // Emit the OpSelectionMerge and OpBranchConditional instructions.
+ current_function_.push_inst(spv::Op::OpSelectionMerge,
+ {merge_label, U32Operand(SpvSelectionControlMaskNone)});
+ current_function_.push_inst(spv::Op::OpBranchConditional,
+ {Value(i->Condition()), true_label, false_label});
+
+ // Emit the `true` and `false` blocks, if they're not being skipped.
+ if (true_label != merge_label) {
+ EmitBlock(true_block);
+ }
+ if (false_label != merge_label) {
+ EmitBlock(false_block);
+ }
+
+ current_function_.push_inst(spv::Op::OpLabel, {merge_label});
+
+ // Emit the OpPhis for the ExitIfs
+ EmitExitPhis(i);
+ }
+
+ /// Emit an access instruction
+ /// @param access the access instruction to emit
+ void EmitAccess(core::ir::Access* access) {
+ auto* ty = access->Result()->Type();
+
+ auto id = Value(access);
+ OperandList operands = {Type(ty), id, Value(access->Object())};
+
+ if (ty->Is<core::type::Pointer>()) {
+ // Use OpAccessChain for accesses into pointer types.
+ for (auto* idx : access->Indices()) {
+ operands.push_back(Value(idx));
+ }
+ current_function_.push_inst(spv::Op::OpAccessChain, std::move(operands));
+ return;
+ }
+
+ // For non-pointer types, we assume that the indices are constants and use
+ // OpCompositeExtract. If we hit a non-constant index into a vector type, use
+ // OpVectorExtractDynamic for it.
+ auto* source_ty = access->Object()->Type();
+ for (auto* idx : access->Indices()) {
+ if (auto* constant = idx->As<core::ir::Constant>()) {
+ // Push the index to the chain and update the current type.
+ auto i = constant->Value()->ValueAs<u32>();
+ operands.push_back(i);
+ source_ty = source_ty->Element(i);
} else {
- ops.push_back(Value(branch.value));
+ // The VarForDynamicIndex transform ensures that only value types that are vectors
+ // will be dynamically indexed, as we can use OpVectorExtractDynamic for this case.
+ TINT_ASSERT(source_ty->Is<core::type::Vector>());
+
+ // If this wasn't the first access in the chain then emit the chain so far as an
+ // OpCompositeExtract, creating a new result ID for the resulting vector.
+ auto vec_id = Value(access->Object());
+ if (operands.size() > 3) {
+ vec_id = module_.NextId();
+ operands[0] = Type(source_ty);
+ operands[1] = vec_id;
+ current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
+ }
+
+ // Now emit the OpVectorExtractDynamic instruction.
+ operands = {Type(ty), id, vec_id, Value(idx)};
+ current_function_.push_inst(spv::Op::OpVectorExtractDynamic, std::move(operands));
+ return;
}
- ops.push_back(branch.label);
}
- current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
+ current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
}
-}
-uint32_t Printer::GetMergeLabel(core::ir::ControlInstruction* ci) {
- return merge_block_labels_.GetOrCreate(ci, [&] { return module_.NextId(); });
-}
+ /// Emit a binary instruction.
+ /// @param binary the binary instruction to emit
+ void EmitBinary(core::ir::Binary* binary) {
+ auto id = Value(binary);
+ auto lhs = Value(binary->LHS());
+ auto rhs = Value(binary->RHS());
+ auto* ty = binary->Result()->Type();
+ auto* lhs_ty = binary->LHS()->Type();
-uint32_t Printer::GetTerminatorBlockLabel(core::ir::Terminator* t) {
- // Walk backwards from `t` until we find a control instruction.
- auto* inst = t->prev;
- while (inst) {
- auto* prev = inst->prev;
- if (auto* ci = inst->As<core::ir::ControlInstruction>()) {
- // This is the last control instruction before `t`, so use its merge block label.
- return GetMergeLabel(ci);
+ // Determine the opcode.
+ spv::Op op = spv::Op::Max;
+ switch (binary->Op()) {
+ case core::ir::BinaryOp::kAdd: {
+ op = ty->is_integer_scalar_or_vector() ? spv::Op::OpIAdd : spv::Op::OpFAdd;
+ break;
+ }
+ case core::ir::BinaryOp::kDivide: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSDiv;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUDiv;
+ } else if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFDiv;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kMultiply: {
+ if (ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpIMul;
+ } else if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFMul;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kSubtract: {
+ op = ty->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
+ break;
+ }
+ case core::ir::BinaryOp::kModulo: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSRem;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUMod;
+ } else if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFRem;
+ }
+ break;
+ }
+
+ case core::ir::BinaryOp::kAnd: {
+ if (ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpBitwiseAnd;
+ } else if (ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalAnd;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kOr: {
+ if (ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpBitwiseOr;
+ } else if (ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalOr;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kXor: {
+ op = spv::Op::OpBitwiseXor;
+ break;
+ }
+
+ case core::ir::BinaryOp::kShiftLeft: {
+ op = spv::Op::OpShiftLeftLogical;
+ break;
+ }
+ case core::ir::BinaryOp::kShiftRight: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpShiftRightArithmetic;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpShiftRightLogical;
+ }
+ break;
+ }
+
+ case core::ir::BinaryOp::kEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpIEqual;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kNotEqual: {
+ if (lhs_ty->is_bool_scalar_or_vector()) {
+ op = spv::Op::OpLogicalNotEqual;
+ } else if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdNotEqual;
+ } else if (lhs_ty->is_integer_scalar_or_vector()) {
+ op = spv::Op::OpINotEqual;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kGreaterThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThan;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kGreaterThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdGreaterThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSGreaterThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUGreaterThanEqual;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kLessThan: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThan;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThan;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThan;
+ }
+ break;
+ }
+ case core::ir::BinaryOp::kLessThanEqual: {
+ if (lhs_ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFOrdLessThanEqual;
+ } else if (lhs_ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSLessThanEqual;
+ } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpULessThanEqual;
+ }
+ break;
+ }
}
- inst = prev;
+
+ // Emit the instruction.
+ current_function_.push_inst(op, {Type(ty), id, lhs, rhs});
}
- // There were no control instructions before `t`, so use the label of the parent block.
- return Label(t->Block());
+ /// Emit a bitcast instruction.
+ /// @param bitcast the bitcast instruction to emit
+ void EmitBitcast(core::ir::Bitcast* bitcast) {
+ auto* ty = bitcast->Result()->Type();
+ if (ty == bitcast->Val()->Type()) {
+ values_.Add(bitcast->Result(), Value(bitcast->Val()));
+ return;
+ }
+ current_function_.push_inst(spv::Op::OpBitcast,
+ {Type(ty), Value(bitcast), Value(bitcast->Val())});
+ }
+
+ /// Emit a builtin function call instruction.
+ /// @param builtin the builtin call instruction to emit
+ void EmitSpirvBuiltinCall(spirv::ir::BuiltinCall* builtin) {
+ auto id = Value(builtin);
+
+ spv::Op op = spv::Op::Max;
+ switch (builtin->Func()) {
+ case spirv::BuiltinFn::kArrayLength:
+ op = spv::Op::OpArrayLength;
+ break;
+ case spirv::BuiltinFn::kAtomicIadd:
+ op = spv::Op::OpAtomicIAdd;
+ break;
+ case spirv::BuiltinFn::kAtomicIsub:
+ op = spv::Op::OpAtomicISub;
+ break;
+ case spirv::BuiltinFn::kAtomicAnd:
+ op = spv::Op::OpAtomicAnd;
+ break;
+ case spirv::BuiltinFn::kAtomicCompareExchange:
+ op = spv::Op::OpAtomicCompareExchange;
+ break;
+ case spirv::BuiltinFn::kAtomicExchange:
+ op = spv::Op::OpAtomicExchange;
+ break;
+ case spirv::BuiltinFn::kAtomicLoad:
+ op = spv::Op::OpAtomicLoad;
+ break;
+ case spirv::BuiltinFn::kAtomicOr:
+ op = spv::Op::OpAtomicOr;
+ break;
+ case spirv::BuiltinFn::kAtomicSmax:
+ op = spv::Op::OpAtomicSMax;
+ break;
+ case spirv::BuiltinFn::kAtomicSmin:
+ op = spv::Op::OpAtomicSMin;
+ break;
+ case spirv::BuiltinFn::kAtomicStore:
+ op = spv::Op::OpAtomicStore;
+ break;
+ case spirv::BuiltinFn::kAtomicUmax:
+ op = spv::Op::OpAtomicUMax;
+ break;
+ case spirv::BuiltinFn::kAtomicUmin:
+ op = spv::Op::OpAtomicUMin;
+ break;
+ case spirv::BuiltinFn::kAtomicXor:
+ op = spv::Op::OpAtomicXor;
+ break;
+ case spirv::BuiltinFn::kDot:
+ op = spv::Op::OpDot;
+ break;
+ case spirv::BuiltinFn::kImageDrefGather:
+ op = spv::Op::OpImageDrefGather;
+ break;
+ case spirv::BuiltinFn::kImageFetch:
+ op = spv::Op::OpImageFetch;
+ break;
+ case spirv::BuiltinFn::kImageGather:
+ op = spv::Op::OpImageGather;
+ break;
+ case spirv::BuiltinFn::kImageQuerySize:
+ module_.PushCapability(SpvCapabilityImageQuery);
+ op = spv::Op::OpImageQuerySize;
+ break;
+ case spirv::BuiltinFn::kImageQuerySizeLod:
+ module_.PushCapability(SpvCapabilityImageQuery);
+ op = spv::Op::OpImageQuerySizeLod;
+ break;
+ case spirv::BuiltinFn::kImageRead:
+ op = spv::Op::OpImageRead;
+ break;
+ case spirv::BuiltinFn::kImageSampleImplicitLod:
+ op = spv::Op::OpImageSampleImplicitLod;
+ break;
+ case spirv::BuiltinFn::kImageSampleExplicitLod:
+ op = spv::Op::OpImageSampleExplicitLod;
+ break;
+ case spirv::BuiltinFn::kImageSampleDrefImplicitLod:
+ op = spv::Op::OpImageSampleDrefImplicitLod;
+ break;
+ case spirv::BuiltinFn::kImageSampleDrefExplicitLod:
+ op = spv::Op::OpImageSampleDrefExplicitLod;
+ break;
+ case spirv::BuiltinFn::kImageWrite:
+ op = spv::Op::OpImageWrite;
+ break;
+ case spirv::BuiltinFn::kMatrixTimesMatrix:
+ op = spv::Op::OpMatrixTimesMatrix;
+ break;
+ case spirv::BuiltinFn::kMatrixTimesScalar:
+ op = spv::Op::OpMatrixTimesScalar;
+ break;
+ case spirv::BuiltinFn::kMatrixTimesVector:
+ op = spv::Op::OpMatrixTimesVector;
+ break;
+ case spirv::BuiltinFn::kSampledImage:
+ op = spv::Op::OpSampledImage;
+ break;
+ case spirv::BuiltinFn::kSdot:
+ module_.PushExtension("SPV_KHR_integer_dot_product");
+ module_.PushCapability(SpvCapabilityDotProductKHR);
+ module_.PushCapability(SpvCapabilityDotProductInput4x8BitPackedKHR);
+ op = spv::Op::OpSDot;
+ break;
+ case spirv::BuiltinFn::kSelect:
+ op = spv::Op::OpSelect;
+ break;
+ case spirv::BuiltinFn::kUdot:
+ module_.PushExtension("SPV_KHR_integer_dot_product");
+ module_.PushCapability(SpvCapabilityDotProductKHR);
+ module_.PushCapability(SpvCapabilityDotProductInput4x8BitPackedKHR);
+ op = spv::Op::OpUDot;
+ break;
+ case spirv::BuiltinFn::kVectorTimesMatrix:
+ op = spv::Op::OpVectorTimesMatrix;
+ break;
+ case spirv::BuiltinFn::kVectorTimesScalar:
+ op = spv::Op::OpVectorTimesScalar;
+ break;
+ case spirv::BuiltinFn::kNone:
+ TINT_ICE() << "undefined spirv ir function";
+ return;
+ }
+
+ OperandList operands;
+ if (!builtin->Result()->Type()->Is<core::type::Void>()) {
+ operands = {Type(builtin->Result()->Type()), id};
+ }
+ for (auto* arg : builtin->Args()) {
+ operands.push_back(Value(arg));
+ }
+ current_function_.push_inst(op, operands);
+ }
+
+ /// Emit a builtin function call instruction.
+ /// @param builtin the builtin call instruction to emit
+ void EmitCoreBuiltinCall(core::ir::CoreBuiltinCall* builtin) {
+ auto* result_ty = builtin->Result()->Type();
+
+ if (builtin->Func() == core::BuiltinFn::kAbs &&
+ result_ty->is_unsigned_integer_scalar_or_vector()) {
+ // abs() is a no-op for unsigned integers.
+ values_.Add(builtin->Result(), Value(builtin->Args()[0]));
+ return;
+ }
+ if ((builtin->Func() == core::BuiltinFn::kAll ||
+ builtin->Func() == core::BuiltinFn::kAny) &&
+ builtin->Args()[0]->Type()->Is<core::type::Bool>()) {
+ // all() and any() are passthroughs for scalar arguments.
+ values_.Add(builtin->Result(), Value(builtin->Args()[0]));
+ return;
+ }
+
+ auto id = Value(builtin);
+
+ spv::Op op = spv::Op::Max;
+ OperandList operands = {Type(result_ty), id};
+
+ // Helper to set up the opcode and operand list for a GLSL extended instruction.
+ auto glsl_ext_inst = [&](enum GLSLstd450 inst) {
+ constexpr const char* kGLSLstd450 = "GLSL.std.450";
+ op = spv::Op::OpExtInst;
+ operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&] {
+ // Import the instruction set the first time it is requested.
+ auto import = module_.NextId();
+ module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)});
+ return import;
+ }));
+ operands.push_back(U32Operand(inst));
+ };
+
+ // Determine the opcode.
+ switch (builtin->Func()) {
+ case core::BuiltinFn::kAbs:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450FAbs);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SAbs);
+ }
+ break;
+ case core::BuiltinFn::kAll:
+ op = spv::Op::OpAll;
+ break;
+ case core::BuiltinFn::kAny:
+ op = spv::Op::OpAny;
+ break;
+ case core::BuiltinFn::kAcos:
+ glsl_ext_inst(GLSLstd450Acos);
+ break;
+ case core::BuiltinFn::kAcosh:
+ glsl_ext_inst(GLSLstd450Acosh);
+ break;
+ case core::BuiltinFn::kAsin:
+ glsl_ext_inst(GLSLstd450Asin);
+ break;
+ case core::BuiltinFn::kAsinh:
+ glsl_ext_inst(GLSLstd450Asinh);
+ break;
+ case core::BuiltinFn::kAtan:
+ glsl_ext_inst(GLSLstd450Atan);
+ break;
+ case core::BuiltinFn::kAtan2:
+ glsl_ext_inst(GLSLstd450Atan2);
+ break;
+ case core::BuiltinFn::kAtanh:
+ glsl_ext_inst(GLSLstd450Atanh);
+ break;
+ case core::BuiltinFn::kClamp:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450NClamp);
+ } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450UClamp);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SClamp);
+ }
+ break;
+ case core::BuiltinFn::kCeil:
+ glsl_ext_inst(GLSLstd450Ceil);
+ break;
+ case core::BuiltinFn::kCos:
+ glsl_ext_inst(GLSLstd450Cos);
+ break;
+ case core::BuiltinFn::kCosh:
+ glsl_ext_inst(GLSLstd450Cosh);
+ break;
+ case core::BuiltinFn::kCountOneBits:
+ op = spv::Op::OpBitCount;
+ break;
+ case core::BuiltinFn::kCross:
+ glsl_ext_inst(GLSLstd450Cross);
+ break;
+ case core::BuiltinFn::kDegrees:
+ glsl_ext_inst(GLSLstd450Degrees);
+ break;
+ case core::BuiltinFn::kDeterminant:
+ glsl_ext_inst(GLSLstd450Determinant);
+ break;
+ case core::BuiltinFn::kDistance:
+ glsl_ext_inst(GLSLstd450Distance);
+ break;
+ case core::BuiltinFn::kDpdx:
+ op = spv::Op::OpDPdx;
+ break;
+ case core::BuiltinFn::kDpdxCoarse:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdxCoarse;
+ break;
+ case core::BuiltinFn::kDpdxFine:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdxFine;
+ break;
+ case core::BuiltinFn::kDpdy:
+ op = spv::Op::OpDPdy;
+ break;
+ case core::BuiltinFn::kDpdyCoarse:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdyCoarse;
+ break;
+ case core::BuiltinFn::kDpdyFine:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdyFine;
+ break;
+ case core::BuiltinFn::kExp:
+ glsl_ext_inst(GLSLstd450Exp);
+ break;
+ case core::BuiltinFn::kExp2:
+ glsl_ext_inst(GLSLstd450Exp2);
+ break;
+ case core::BuiltinFn::kExtractBits:
+ op = result_ty->is_signed_integer_scalar_or_vector() ? spv::Op::OpBitFieldSExtract
+ : spv::Op::OpBitFieldUExtract;
+ break;
+ case core::BuiltinFn::kFaceForward:
+ glsl_ext_inst(GLSLstd450FaceForward);
+ break;
+ case core::BuiltinFn::kFloor:
+ glsl_ext_inst(GLSLstd450Floor);
+ break;
+ case core::BuiltinFn::kFma:
+ glsl_ext_inst(GLSLstd450Fma);
+ break;
+ case core::BuiltinFn::kFract:
+ glsl_ext_inst(GLSLstd450Fract);
+ break;
+ case core::BuiltinFn::kFrexp:
+ glsl_ext_inst(GLSLstd450FrexpStruct);
+ break;
+ case core::BuiltinFn::kFwidth:
+ op = spv::Op::OpFwidth;
+ break;
+ case core::BuiltinFn::kFwidthCoarse:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpFwidthCoarse;
+ break;
+ case core::BuiltinFn::kFwidthFine:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpFwidthFine;
+ break;
+ case core::BuiltinFn::kInsertBits:
+ op = spv::Op::OpBitFieldInsert;
+ break;
+ case core::BuiltinFn::kInverseSqrt:
+ glsl_ext_inst(GLSLstd450InverseSqrt);
+ break;
+ case core::BuiltinFn::kLdexp:
+ glsl_ext_inst(GLSLstd450Ldexp);
+ break;
+ case core::BuiltinFn::kLength:
+ glsl_ext_inst(GLSLstd450Length);
+ break;
+ case core::BuiltinFn::kLog:
+ glsl_ext_inst(GLSLstd450Log);
+ break;
+ case core::BuiltinFn::kLog2:
+ glsl_ext_inst(GLSLstd450Log2);
+ break;
+ case core::BuiltinFn::kMax:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450FMax);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SMax);
+ } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450UMax);
+ }
+ break;
+ case core::BuiltinFn::kMin:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450FMin);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SMin);
+ } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450UMin);
+ }
+ break;
+ case core::BuiltinFn::kMix:
+ glsl_ext_inst(GLSLstd450FMix);
+ break;
+ case core::BuiltinFn::kModf:
+ glsl_ext_inst(GLSLstd450ModfStruct);
+ break;
+ case core::BuiltinFn::kNormalize:
+ glsl_ext_inst(GLSLstd450Normalize);
+ break;
+ case core::BuiltinFn::kPack2X16Float:
+ glsl_ext_inst(GLSLstd450PackHalf2x16);
+ break;
+ case core::BuiltinFn::kPack2X16Snorm:
+ glsl_ext_inst(GLSLstd450PackSnorm2x16);
+ break;
+ case core::BuiltinFn::kPack2X16Unorm:
+ glsl_ext_inst(GLSLstd450PackUnorm2x16);
+ break;
+ case core::BuiltinFn::kPack4X8Snorm:
+ glsl_ext_inst(GLSLstd450PackSnorm4x8);
+ break;
+ case core::BuiltinFn::kPack4X8Unorm:
+ glsl_ext_inst(GLSLstd450PackUnorm4x8);
+ break;
+ case core::BuiltinFn::kPow:
+ glsl_ext_inst(GLSLstd450Pow);
+ break;
+ case core::BuiltinFn::kQuantizeToF16:
+ op = spv::Op::OpQuantizeToF16;
+ break;
+ case core::BuiltinFn::kRadians:
+ glsl_ext_inst(GLSLstd450Radians);
+ break;
+ case core::BuiltinFn::kReflect:
+ glsl_ext_inst(GLSLstd450Reflect);
+ break;
+ case core::BuiltinFn::kRefract:
+ glsl_ext_inst(GLSLstd450Refract);
+ break;
+ case core::BuiltinFn::kReverseBits:
+ op = spv::Op::OpBitReverse;
+ break;
+ case core::BuiltinFn::kRound:
+ glsl_ext_inst(GLSLstd450RoundEven);
+ break;
+ case core::BuiltinFn::kSign:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450FSign);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SSign);
+ }
+ break;
+ case core::BuiltinFn::kSin:
+ glsl_ext_inst(GLSLstd450Sin);
+ break;
+ case core::BuiltinFn::kSinh:
+ glsl_ext_inst(GLSLstd450Sinh);
+ break;
+ case core::BuiltinFn::kSmoothstep:
+ glsl_ext_inst(GLSLstd450SmoothStep);
+ break;
+ case core::BuiltinFn::kSqrt:
+ glsl_ext_inst(GLSLstd450Sqrt);
+ break;
+ case core::BuiltinFn::kStep:
+ glsl_ext_inst(GLSLstd450Step);
+ break;
+ case core::BuiltinFn::kStorageBarrier:
+ op = spv::Op::OpControlBarrier;
+ operands.clear();
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(
+ Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::UniformMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
+ break;
+ case core::BuiltinFn::kSubgroupBallot:
+ module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
+ op = spv::Op::OpGroupNonUniformBallot;
+ operands.push_back(Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup))));
+ operands.push_back(Constant(ir_.constant_values.Get(true)));
+ break;
+ case core::BuiltinFn::kSubgroupBroadcast:
+ module_.PushCapability(SpvCapabilityGroupNonUniformBallot);
+ op = spv::Op::OpGroupNonUniformBroadcast;
+ operands.push_back(Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup))));
+ break;
+ case core::BuiltinFn::kTan:
+ glsl_ext_inst(GLSLstd450Tan);
+ break;
+ case core::BuiltinFn::kTanh:
+ glsl_ext_inst(GLSLstd450Tanh);
+ break;
+ case core::BuiltinFn::kTextureBarrier:
+ op = spv::Op::OpControlBarrier;
+ operands.clear();
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(
+ Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::ImageMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
+ break;
+ case core::BuiltinFn::kTextureNumLevels:
+ module_.PushCapability(SpvCapabilityImageQuery);
+ op = spv::Op::OpImageQueryLevels;
+ break;
+ case core::BuiltinFn::kTextureNumSamples:
+ module_.PushCapability(SpvCapabilityImageQuery);
+ op = spv::Op::OpImageQuerySamples;
+ break;
+ case core::BuiltinFn::kTranspose:
+ op = spv::Op::OpTranspose;
+ break;
+ case core::BuiltinFn::kTrunc:
+ glsl_ext_inst(GLSLstd450Trunc);
+ break;
+ case core::BuiltinFn::kUnpack2X16Float:
+ glsl_ext_inst(GLSLstd450UnpackHalf2x16);
+ break;
+ case core::BuiltinFn::kUnpack2X16Snorm:
+ glsl_ext_inst(GLSLstd450UnpackSnorm2x16);
+ break;
+ case core::BuiltinFn::kUnpack2X16Unorm:
+ glsl_ext_inst(GLSLstd450UnpackUnorm2x16);
+ break;
+ case core::BuiltinFn::kUnpack4X8Snorm:
+ glsl_ext_inst(GLSLstd450UnpackSnorm4x8);
+ break;
+ case core::BuiltinFn::kUnpack4X8Unorm:
+ glsl_ext_inst(GLSLstd450UnpackUnorm4x8);
+ break;
+ case core::BuiltinFn::kWorkgroupBarrier:
+ op = spv::Op::OpControlBarrier;
+ operands.clear();
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(Constant(b_.ConstantValue(u32(spv::Scope::Workgroup))));
+ operands.push_back(
+ Constant(b_.ConstantValue(u32(spv::MemorySemanticsMask::WorkgroupMemory |
+ spv::MemorySemanticsMask::AcquireRelease))));
+ break;
+ default:
+ TINT_ICE() << "unimplemented builtin function: " << builtin->Func();
+ }
+ TINT_ASSERT(op != spv::Op::Max);
+
+ // Add the arguments to the builtin call.
+ for (auto* arg : builtin->Args()) {
+ operands.push_back(Value(arg));
+ }
+
+ // Emit the instruction.
+ current_function_.push_inst(op, operands);
+ }
+
+ /// Emit a construct instruction.
+ /// @param construct the construct instruction to emit
+ void EmitConstruct(core::ir::Construct* construct) {
+ // If there is just a single argument with the same type as the result, this is an identity
+ // constructor and we can just pass through the ID of the argument.
+ if (construct->Args().Length() == 1 &&
+ construct->Result()->Type() == construct->Args()[0]->Type()) {
+ values_.Add(construct->Result(), Value(construct->Args()[0]));
+ return;
+ }
+
+ OperandList operands = {Type(construct->Result()->Type()), Value(construct)};
+ for (auto* arg : construct->Args()) {
+ operands.push_back(Value(arg));
+ }
+ current_function_.push_inst(spv::Op::OpCompositeConstruct, std::move(operands));
+ }
+
+ /// Emit a convert instruction.
+ /// @param convert the convert instruction to emit
+ void EmitConvert(core::ir::Convert* convert) {
+ auto* res_ty = convert->Result()->Type();
+ auto* arg_ty = convert->Args()[0]->Type();
+
+ OperandList operands = {Type(convert->Result()->Type()), Value(convert)};
+ for (auto* arg : convert->Args()) {
+ operands.push_back(Value(arg));
+ }
+
+ spv::Op op = spv::Op::Max;
+ if (res_ty->is_signed_integer_scalar_or_vector() && arg_ty->is_float_scalar_or_vector()) {
+ // float to signed int.
+ op = spv::Op::OpConvertFToS;
+ } else if (res_ty->is_unsigned_integer_scalar_or_vector() &&
+ arg_ty->is_float_scalar_or_vector()) {
+ // float to unsigned int.
+ op = spv::Op::OpConvertFToU;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_signed_integer_scalar_or_vector()) {
+ // signed int to float.
+ op = spv::Op::OpConvertSToF;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_unsigned_integer_scalar_or_vector()) {
+ // unsigned int to float.
+ op = spv::Op::OpConvertUToF;
+ } else if (res_ty->is_float_scalar_or_vector() && arg_ty->is_float_scalar_or_vector() &&
+ res_ty->Size() != arg_ty->Size()) {
+ // float to float (different bitwidth).
+ op = spv::Op::OpFConvert;
+ } else if (res_ty->is_integer_scalar_or_vector() && arg_ty->is_integer_scalar_or_vector() &&
+ res_ty->Size() == arg_ty->Size()) {
+ // int to int (same bitwidth, different signedness).
+ op = spv::Op::OpBitcast;
+ } else if (res_ty->is_bool_scalar_or_vector()) {
+ if (arg_ty->is_integer_scalar_or_vector()) {
+ // int to bool.
+ op = spv::Op::OpINotEqual;
+ } else {
+ // float to bool.
+ op = spv::Op::OpFUnordNotEqual;
+ }
+ operands.push_back(ConstantNull(arg_ty));
+ } else if (arg_ty->is_bool_scalar_or_vector()) {
+ // Select between constant one and zero, splatting them to vectors if necessary.
+ core::ir::Constant* one = nullptr;
+ core::ir::Constant* zero = nullptr;
+ Switch(
+ res_ty->DeepestElement(), //
+ [&](const core::type::F32*) {
+ one = b_.Constant(1_f);
+ zero = b_.Constant(0_f);
+ },
+ [&](const core::type::F16*) {
+ one = b_.Constant(1_h);
+ zero = b_.Constant(0_h);
+ },
+ [&](const core::type::I32*) {
+ one = b_.Constant(1_i);
+ zero = b_.Constant(0_i);
+ },
+ [&](const core::type::U32*) {
+ one = b_.Constant(1_u);
+ zero = b_.Constant(0_u);
+ });
+ TINT_ASSERT_OR_RETURN(one && zero);
+
+ if (auto* vec = res_ty->As<core::type::Vector>()) {
+ // Splat the scalars into vectors.
+ one = b_.Splat(vec, one, vec->Width());
+ zero = b_.Splat(vec, zero, vec->Width());
+ }
+
+ op = spv::Op::OpSelect;
+ operands.push_back(Constant(b_.ConstantValue(one)));
+ operands.push_back(Constant(b_.ConstantValue(zero)));
+ } else {
+ TINT_ICE() << "unhandled convert instruction";
+ }
+
+ current_function_.push_inst(op, std::move(operands));
+ }
+
+ /// Emit a load instruction.
+ /// @param load the load instruction to emit
+ void EmitLoad(core::ir::Load* load) {
+ current_function_.push_inst(
+ spv::Op::OpLoad, {Type(load->Result()->Type()), Value(load), Value(load->From())});
+ }
+
+ /// Emit a load vector element instruction.
+ /// @param load the load vector element instruction to emit
+ void EmitLoadVectorElement(core::ir::LoadVectorElement* load) {
+ auto* vec_ptr_ty = load->From()->Type()->As<core::type::Pointer>();
+ auto* el_ty = load->Result()->Type();
+ auto* el_ptr_ty = ir_.Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
+ auto el_ptr_id = module_.NextId();
+ current_function_.push_inst(
+ spv::Op::OpAccessChain,
+ {Type(el_ptr_ty), el_ptr_id, Value(load->From()), Value(load->Index())});
+ current_function_.push_inst(spv::Op::OpLoad,
+ {Type(load->Result()->Type()), Value(load), el_ptr_id});
+ }
+
+ /// Emit a loop instruction.
+ /// @param loop the loop instruction to emit
+ void EmitLoop(core::ir::Loop* loop) {
+ auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
+ auto body_label = Label(loop->Body());
+ auto continuing_label = Label(loop->Continuing());
+
+ auto header_label = module_.NextId();
+ TINT_SCOPED_ASSIGNMENT(loop_header_label_, header_label);
+
+ auto merge_label = GetMergeLabel(loop);
+ TINT_SCOPED_ASSIGNMENT(loop_merge_label_, merge_label);
+
+ if (init_label != 0) {
+ // Emit the loop initializer.
+ current_function_.push_inst(spv::Op::OpBranch, {init_label});
+ EmitBlock(loop->Initializer());
+ } else {
+ // No initializer. Branch to body.
+ current_function_.push_inst(spv::Op::OpBranch, {header_label});
+ }
+
+ // Emit the loop body header, which contains the OpLoopMerge and OpPhis.
+ // This then unconditionally branches to body_label
+ current_function_.push_inst(spv::Op::OpLabel, {header_label});
+ EmitIncomingPhis(loop->Body());
+ current_function_.push_inst(spv::Op::OpLoopMerge, {merge_label, continuing_label,
+ U32Operand(SpvLoopControlMaskNone)});
+ current_function_.push_inst(spv::Op::OpBranch, {body_label});
+
+ // Emit the loop body
+ current_function_.push_inst(spv::Op::OpLabel, {body_label});
+ EmitBlockInstructions(loop->Body());
+
+ // Emit the loop continuing block.
+ if (loop->Continuing()->HasTerminator()) {
+ EmitBlock(loop->Continuing());
+ } else {
+ // We still need to emit a continuing block with a back-edge, even if it is unreachable.
+ current_function_.push_inst(spv::Op::OpLabel, {continuing_label});
+ current_function_.push_inst(spv::Op::OpBranch, {header_label});
+ }
+
+ // Emit the loop merge block.
+ current_function_.push_inst(spv::Op::OpLabel, {merge_label});
+
+ // Emit the OpPhis for the ExitLoops
+ EmitExitPhis(loop);
+ }
+
+ /// Emit a switch instruction.
+ /// @param swtch the switch instruction to emit
+ void EmitSwitch(core::ir::Switch* swtch) {
+ // Find the default selector. There must be exactly one.
+ uint32_t default_label = 0u;
+ for (auto& c : swtch->Cases()) {
+ for (auto& sel : c.selectors) {
+ if (sel.IsDefault()) {
+ default_label = Label(c.Block());
+ }
+ }
+ }
+ TINT_ASSERT(default_label != 0u);
+
+ // Build the operands to the OpSwitch instruction.
+ OperandList switch_operands = {Value(swtch->Condition()), default_label};
+ for (auto& c : swtch->Cases()) {
+ auto label = Label(c.Block());
+ for (auto& sel : c.selectors) {
+ if (sel.IsDefault()) {
+ continue;
+ }
+ switch_operands.push_back(sel.val->Value()->ValueAs<uint32_t>());
+ switch_operands.push_back(label);
+ }
+ }
+
+ uint32_t merge_label = GetMergeLabel(swtch);
+ TINT_SCOPED_ASSIGNMENT(switch_merge_label_, merge_label);
+
+ // Emit the OpSelectionMerge and OpSwitch instructions.
+ current_function_.push_inst(spv::Op::OpSelectionMerge,
+ {merge_label, U32Operand(SpvSelectionControlMaskNone)});
+ current_function_.push_inst(spv::Op::OpSwitch, switch_operands);
+
+ // Emit the cases.
+ for (auto& c : swtch->Cases()) {
+ EmitBlock(c.Block());
+ }
+
+ // Emit the switch merge block.
+ current_function_.push_inst(spv::Op::OpLabel, {merge_label});
+
+ // Emit the OpPhis for the ExitSwitches
+ EmitExitPhis(swtch);
+ }
+
+ /// Emit a swizzle instruction.
+ /// @param swizzle the swizzle instruction to emit
+ void EmitSwizzle(core::ir::Swizzle* swizzle) {
+ auto id = Value(swizzle);
+ auto obj = Value(swizzle->Object());
+ OperandList operands = {Type(swizzle->Result()->Type()), id, obj, obj};
+ for (auto idx : swizzle->Indices()) {
+ operands.push_back(idx);
+ }
+ current_function_.push_inst(spv::Op::OpVectorShuffle, operands);
+ }
+
+ /// Emit a store instruction.
+ /// @param store the store instruction to emit
+ void EmitStore(core::ir::Store* store) {
+ current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
+ }
+
+ /// Emit a store vector element instruction.
+ /// @param store the store vector element instruction to emit
+ void EmitStoreVectorElement(core::ir::StoreVectorElement* store) {
+ auto* vec_ptr_ty = store->To()->Type()->As<core::type::Pointer>();
+ auto* el_ty = store->Value()->Type();
+ auto* el_ptr_ty = ir_.Types().ptr(vec_ptr_ty->AddressSpace(), el_ty, vec_ptr_ty->Access());
+ auto el_ptr_id = module_.NextId();
+ current_function_.push_inst(
+ spv::Op::OpAccessChain,
+ {Type(el_ptr_ty), el_ptr_id, Value(store->To()), Value(store->Index())});
+ current_function_.push_inst(spv::Op::OpStore, {el_ptr_id, Value(store->Value())});
+ }
+
+ /// Emit a unary instruction.
+ /// @param unary the unary instruction to emit
+ void EmitUnary(core::ir::Unary* unary) {
+ auto id = Value(unary);
+ auto* ty = unary->Result()->Type();
+ spv::Op op = spv::Op::Max;
+ switch (unary->Op()) {
+ case core::ir::UnaryOp::kComplement:
+ op = spv::Op::OpNot;
+ break;
+ case core::ir::UnaryOp::kNegation:
+ if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFNegate;
+ } else if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSNegate;
+ }
+ break;
+ }
+ current_function_.push_inst(op, {Type(ty), id, Value(unary->Val())});
+ }
+
+ /// Emit a user call instruction.
+ /// @param call the user call instruction to emit
+ void EmitUserCall(core::ir::UserCall* call) {
+ auto id = Value(call);
+ OperandList operands = {Type(call->Result()->Type()), id, Value(call->Target())};
+ for (auto* arg : call->Args()) {
+ operands.push_back(Value(arg));
+ }
+ current_function_.push_inst(spv::Op::OpFunctionCall, operands);
+ }
+
+ /// Emit IO attributes.
+ /// @param id the ID of the variable to decorate
+ /// @param attrs the shader IO attrs
+ /// @param addrspace the address of the variable
+ void EmitIOAttributes(uint32_t id,
+ const core::ir::IOAttributes& attrs,
+ core::AddressSpace addrspace) {
+ if (attrs.location) {
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationLocation), *attrs.location});
+ }
+ if (attrs.index) {
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationIndex), *attrs.index});
+ }
+ if (attrs.interpolation) {
+ switch (attrs.interpolation->type) {
+ case core::InterpolationType::kLinear:
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationNoPerspective)});
+ break;
+ case core::InterpolationType::kFlat:
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationFlat)});
+ break;
+ case core::InterpolationType::kPerspective:
+ case core::InterpolationType::kUndefined:
+ break;
+ }
+ switch (attrs.interpolation->sampling) {
+ case core::InterpolationSampling::kCentroid:
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationCentroid)});
+ break;
+ case core::InterpolationSampling::kSample:
+ module_.PushCapability(SpvCapabilitySampleRateShading);
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationSample)});
+ break;
+ case core::InterpolationSampling::kCenter:
+ case core::InterpolationSampling::kUndefined:
+ break;
+ }
+ }
+ if (attrs.builtin) {
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBuiltIn),
+ Builtin(*attrs.builtin, addrspace)});
+ }
+ if (attrs.invariant) {
+ module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationInvariant)});
+ }
+ }
+
+ /// Emit a var instruction.
+ /// @param var the var instruction to emit
+ void EmitVar(core::ir::Var* var) {
+ auto id = Value(var);
+ auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
+ auto* store_ty = ptr->StoreType();
+ auto ty = Type(ptr);
+
+ switch (ptr->AddressSpace()) {
+ case core::AddressSpace::kFunction: {
+ TINT_ASSERT(current_function_);
+ if (var->Initializer()) {
+ current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)});
+ current_function_.push_inst(spv::Op::OpStore, {id, Value(var->Initializer())});
+ } else {
+ current_function_.push_var(
+ {ty, id, U32Operand(SpvStorageClassFunction), ConstantNull(store_ty)});
+ }
+ break;
+ }
+ case core::AddressSpace::kIn: {
+ TINT_ASSERT(!current_function_);
+ module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassInput)});
+ EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kIn);
+ break;
+ }
+ case core::AddressSpace::kPrivate: {
+ TINT_ASSERT(!current_function_);
+ OperandList operands = {ty, id, U32Operand(SpvStorageClassPrivate)};
+ if (var->Initializer()) {
+ TINT_ASSERT(var->Initializer()->Is<core::ir::Constant>());
+ operands.push_back(Value(var->Initializer()));
+ } else {
+ operands.push_back(ConstantNull(store_ty));
+ }
+ module_.PushType(spv::Op::OpVariable, operands);
+ break;
+ }
+ case core::AddressSpace::kPushConstant: {
+ TINT_ASSERT(!current_function_);
+ module_.PushType(spv::Op::OpVariable,
+ {ty, id, U32Operand(SpvStorageClassPushConstant)});
+ break;
+ }
+ case core::AddressSpace::kOut: {
+ TINT_ASSERT(!current_function_);
+ module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassOutput)});
+ EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kOut);
+ break;
+ }
+ case core::AddressSpace::kHandle:
+ case core::AddressSpace::kStorage:
+ case core::AddressSpace::kUniform: {
+ TINT_ASSERT(!current_function_);
+ module_.PushType(spv::Op::OpVariable,
+ {ty, id, U32Operand(StorageClass(ptr->AddressSpace()))});
+ auto bp = var->BindingPoint().value();
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationDescriptorSet), bp.group});
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationBinding), bp.binding});
+
+ // Add NonReadable and NonWritable decorations to storage textures and buffers.
+ auto* st = store_ty->As<core::type::StorageTexture>();
+ if (st || store_ty->Is<core::type::Struct>()) {
+ auto access = st ? st->access() : ptr->Access();
+ if (access == core::Access::kRead) {
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationNonWritable)});
+ } else if (access == core::Access::kWrite) {
+ module_.PushAnnot(spv::Op::OpDecorate,
+ {id, U32Operand(SpvDecorationNonReadable)});
+ }
+ }
+ break;
+ }
+ case core::AddressSpace::kWorkgroup: {
+ TINT_ASSERT(!current_function_);
+ OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
+ if (zero_init_workgroup_memory_) {
+ // If requested, use the VK_KHR_zero_initialize_workgroup_memory to
+ // zero-initialize the workgroup variable using an null constant initializer.
+ operands.push_back(ConstantNull(store_ty));
+ }
+ module_.PushType(spv::Op::OpVariable, operands);
+ break;
+ }
+ default: {
+ TINT_ICE() << "unimplemented variable address space " << ptr->AddressSpace();
+ }
+ }
+
+ // Set the name if present.
+ if (auto name = ir_.NameOf(var)) {
+ module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
+ }
+ }
+
+ /// Emit a let instruction.
+ /// @param let the let instruction to emit
+ void EmitLet(core::ir::Let* let) {
+ auto id = Value(let->Value());
+ values_.Add(let->Result(), id);
+ }
+
+ /// Emit the OpPhis for the given flow control instruction.
+ /// @param inst the flow control instruction
+ void EmitExitPhis(core::ir::ControlInstruction* inst) {
+ struct Branch {
+ uint32_t label = 0;
+ core::ir::Value* value = nullptr;
+ bool operator<(const Branch& other) const { return label < other.label; }
+ };
+
+ auto results = inst->Results();
+ for (size_t index = 0; index < results.Length(); index++) {
+ auto* result = results[index];
+ auto* ty = result->Type();
+
+ Vector<Branch, 8> branches;
+ branches.Reserve(inst->Exits().Count());
+ for (auto& exit : inst->Exits()) {
+ branches.Push(Branch{GetTerminatorBlockLabel(exit), exit->Args()[index]});
+ }
+ branches.Sort(); // Sort the branches by label to ensure deterministic output
+
+ OperandList ops{Type(ty), Value(result)};
+ for (auto& branch : branches) {
+ if (branch.value == nullptr) {
+ ops.push_back(Undef(ty));
+ } else {
+ ops.push_back(Value(branch.value));
+ }
+ ops.push_back(branch.label);
+ }
+ current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
+ }
+ }
+
+ /// Get the ID of the label of the merge block for a control instruction.
+ /// @param ci the control instruction to get the merge label for
+ /// @returns the label ID
+ uint32_t GetMergeLabel(core::ir::ControlInstruction* ci) {
+ return merge_block_labels_.GetOrCreate(ci, [&] { return module_.NextId(); });
+ }
+
+ /// Get the ID of the label of the block that will contain a terminator instruction.
+ /// @param t the terminator instruction to get the block label for
+ /// @returns the label ID
+ uint32_t GetTerminatorBlockLabel(core::ir::Terminator* t) {
+ // Walk backwards from `t` until we find a control instruction.
+ auto* inst = t->prev;
+ while (inst) {
+ auto* prev = inst->prev;
+ if (auto* ci = inst->As<core::ir::ControlInstruction>()) {
+ // This is the last control instruction before `t`, so use its merge block label.
+ return GetMergeLabel(ci);
+ }
+ inst = prev;
+ }
+
+ // There were no control instructions before `t`, so use the label of the parent block.
+ return Label(t->Block());
+ }
+
+ /// Convert a texel format to the corresponding SPIR-V enum value, adding required capabilities.
+ /// @param format the format to convert
+ /// @returns the enum value of the corresponding SPIR-V texel format
+ uint32_t TexelFormat(const core::TexelFormat format) {
+ switch (format) {
+ case core::TexelFormat::kBgra8Unorm:
+ TINT_ICE() << "bgra8unorm should have been polyfilled to rgba8unorm";
+ return SpvImageFormatUnknown;
+ case core::TexelFormat::kR32Uint:
+ return SpvImageFormatR32ui;
+ case core::TexelFormat::kR32Sint:
+ return SpvImageFormatR32i;
+ case core::TexelFormat::kR32Float:
+ return SpvImageFormatR32f;
+ case core::TexelFormat::kRgba8Unorm:
+ return SpvImageFormatRgba8;
+ case core::TexelFormat::kRgba8Snorm:
+ return SpvImageFormatRgba8Snorm;
+ case core::TexelFormat::kRgba8Uint:
+ return SpvImageFormatRgba8ui;
+ case core::TexelFormat::kRgba8Sint:
+ return SpvImageFormatRgba8i;
+ case core::TexelFormat::kRg32Uint:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32ui;
+ case core::TexelFormat::kRg32Sint:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32i;
+ case core::TexelFormat::kRg32Float:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32f;
+ case core::TexelFormat::kRgba16Uint:
+ return SpvImageFormatRgba16ui;
+ case core::TexelFormat::kRgba16Sint:
+ return SpvImageFormatRgba16i;
+ case core::TexelFormat::kRgba16Float:
+ return SpvImageFormatRgba16f;
+ case core::TexelFormat::kRgba32Uint:
+ return SpvImageFormatRgba32ui;
+ case core::TexelFormat::kRgba32Sint:
+ return SpvImageFormatRgba32i;
+ case core::TexelFormat::kRgba32Float:
+ return SpvImageFormatRgba32f;
+ case core::TexelFormat::kUndefined:
+ return SpvImageFormatUnknown;
+ }
+ return SpvImageFormatUnknown;
+ }
+};
+
+} // namespace
+
+tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module,
+ bool zero_init_workgroup_memory) {
+ return Printer{module, zero_init_workgroup_memory}.Code();
}
-uint32_t Printer::TexelFormat(const core::TexelFormat format) {
- switch (format) {
- case core::TexelFormat::kBgra8Unorm:
- TINT_ICE() << "bgra8unorm should have been polyfilled to rgba8unorm";
- return SpvImageFormatUnknown;
- case core::TexelFormat::kR32Uint:
- return SpvImageFormatR32ui;
- case core::TexelFormat::kR32Sint:
- return SpvImageFormatR32i;
- case core::TexelFormat::kR32Float:
- return SpvImageFormatR32f;
- case core::TexelFormat::kRgba8Unorm:
- return SpvImageFormatRgba8;
- case core::TexelFormat::kRgba8Snorm:
- return SpvImageFormatRgba8Snorm;
- case core::TexelFormat::kRgba8Uint:
- return SpvImageFormatRgba8ui;
- case core::TexelFormat::kRgba8Sint:
- return SpvImageFormatRgba8i;
- case core::TexelFormat::kRg32Uint:
- module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
- return SpvImageFormatRg32ui;
- case core::TexelFormat::kRg32Sint:
- module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
- return SpvImageFormatRg32i;
- case core::TexelFormat::kRg32Float:
- module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
- return SpvImageFormatRg32f;
- case core::TexelFormat::kRgba16Uint:
- return SpvImageFormatRgba16ui;
- case core::TexelFormat::kRgba16Sint:
- return SpvImageFormatRgba16i;
- case core::TexelFormat::kRgba16Float:
- return SpvImageFormatRgba16f;
- case core::TexelFormat::kRgba32Uint:
- return SpvImageFormatRgba32ui;
- case core::TexelFormat::kRgba32Sint:
- return SpvImageFormatRgba32i;
- case core::TexelFormat::kRgba32Float:
- return SpvImageFormatRgba32f;
- case core::TexelFormat::kUndefined:
- return SpvImageFormatUnknown;
- }
- return SpvImageFormatUnknown;
+tint::Result<Module> PrintModule(core::ir::Module& module, bool zero_init_workgroup_memory) {
+ return Printer{module, zero_init_workgroup_memory}.Module();
}
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.h b/src/tint/lang/spirv/writer/printer/printer.h
index 2ce0544..e82b01c 100644
--- a/src/tint/lang/spirv/writer/printer/printer.h
+++ b/src/tint/lang/spirv/writer/printer/printer.h
@@ -28,344 +28,31 @@
#ifndef SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_
#define SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_
-#include <string>
+#include <cstdint>
#include <vector>
-#include "src/tint/lang/core/address_space.h"
-#include "src/tint/lang/core/builtin_value.h"
-#include "src/tint/lang/core/constant/value.h"
-#include "src/tint/lang/core/ir/builder.h"
-#include "src/tint/lang/core/ir/constant.h"
-#include "src/tint/lang/core/texel_format.h"
-#include "src/tint/lang/spirv/ir/builtin_call.h"
-#include "src/tint/lang/spirv/writer/common/binary_writer.h"
-#include "src/tint/lang/spirv/writer/common/function.h"
#include "src/tint/lang/spirv/writer/common/module.h"
-#include "src/tint/utils/containers/hashmap.h"
-#include "src/tint/utils/containers/vector.h"
-#include "src/tint/utils/diagnostic/diagnostic.h"
#include "src/tint/utils/result/result.h"
-#include "src/tint/utils/symbol/symbol.h"
// Forward declarations
namespace tint::core::ir {
-class Access;
-class Binary;
-class Bitcast;
-class Block;
-class BlockParam;
-class Construct;
-class ControlInstruction;
-class Convert;
-class CoreBuiltinCall;
-class ExitIf;
-class ExitLoop;
-class ExitSwitch;
-class Function;
-class If;
-class Let;
-class Load;
-class LoadVectorElement;
-class Loop;
class Module;
-class MultiInBlock;
-class Store;
-class StoreVectorElement;
-class Switch;
-class Swizzle;
-class Terminator;
-class Unary;
-class UserCall;
-class Value;
-class Var;
} // namespace tint::core::ir
-namespace tint::core::type {
-class Struct;
-class Texture;
-class Type;
-} // namespace tint::core::type
namespace tint::spirv::writer {
-/// Implementation class for SPIR-V writer
-class Printer {
- public:
- /// Constructor
- /// @param module the Tint IR module to generate
- /// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
- /// storage class with OpConstantNull
- Printer(core::ir::Module& module, bool zero_init_workgroup_memory);
+/// @returns the generated SPIR-V instructions on success, or failure
+/// @param module the Tint IR module to generate
+/// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
+/// storage class with OpConstantNull
+tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module,
+ bool zero_init_workgroup_memory);
- /// @returns the generated SPIR-V binary on success, or failure
- tint::Result<std::vector<uint32_t>> Generate();
-
- /// @returns the module that this writer has produced
- writer::Module& Module() { return module_; }
-
- /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
- /// @param constant the constant to get the ID for
- /// @returns the result ID of the constant
- uint32_t Constant(core::ir::Constant* constant);
-
- /// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
- /// @param ty the type to get the ID for
- /// @returns the result ID of the type
- uint32_t Type(const core::type::Type* ty);
-
- private:
- /// Convert a builtin to the corresponding SPIR-V enum value, taking into account the target
- /// address space. Adds any capabilities needed for the builtin.
- /// @param builtin the builtin to convert
- /// @param addrspace the address space the builtin is being used in
- /// @returns the enum value of the corresponding SPIR-V builtin
- uint32_t Builtin(core::BuiltinValue builtin, core::AddressSpace addrspace);
-
- /// Convert a texel format to the corresponding SPIR-V enum value, adding required capabilities.
- /// @param format the format to convert
- /// @returns the enum value of the corresponding SPIR-V texel format
- uint32_t TexelFormat(const core::TexelFormat format);
-
- /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
- /// @param constant the constant to get the ID for
- /// @returns the result ID of the constant
- uint32_t Constant(const core::constant::Value* constant);
-
- /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
- /// @param type the type to get the ID for
- /// @returns the result ID of the OpConstantNull instruction
- uint32_t ConstantNull(const core::type::Type* type);
-
- /// Get the ID of the label for `block`.
- /// @param block the block to get the label ID for
- /// @returns the ID of the block's label
- uint32_t Label(core::ir::Block* block);
-
- /// Get the result ID of the value `value`, emitting its instruction if necessary.
- /// @param value the value to get the ID for
- /// @returns the result ID of the value
- uint32_t Value(core::ir::Value* value);
-
- /// Get the result ID of the instruction result `value`, emitting its instruction if necessary.
- /// @param inst the instruction to get the ID for
- /// @returns the result ID of the instruction
- uint32_t Value(core::ir::Instruction* inst);
-
- /// Get the result ID of the OpUndef instruction with type `ty`, emitting it if necessary.
- /// @param ty the type of the undef value
- /// @returns the result ID of the instruction
- uint32_t Undef(const core::type::Type* ty);
-
- /// Emit a struct type.
- /// @param id the result ID to use
- /// @param str the struct type to emit
- void EmitStructType(uint32_t id, const core::type::Struct* str);
-
- /// Emit a texture type.
- /// @param id the result ID to use
- /// @param texture the texture type to emit
- void EmitTextureType(uint32_t id, const core::type::Texture* texture);
-
- /// Emit a function.
- /// @param func the function to emit
- void EmitFunction(core::ir::Function* func);
-
- /// Emit entry point declarations for a function.
- /// @param func the function to emit entry point declarations for
- /// @param id the result ID of the function declaration
- void EmitEntryPoint(core::ir::Function* func, uint32_t id);
-
- /// Emit a block, including the initial OpLabel, OpPhis and instructions.
- /// @param block the block to emit
- void EmitBlock(core::ir::Block* block);
-
- /// Emit all OpPhi nodes for incoming branches to @p block.
- /// @param block the block to emit the OpPhis for
- void EmitIncomingPhis(core::ir::MultiInBlock* block);
-
- /// Emit all instructions of @p block.
- /// @param block the block's instructions to emit
- void EmitBlockInstructions(core::ir::Block* block);
-
- /// Emit the root block.
- /// @param root_block the root block to emit
- void EmitRootBlock(core::ir::Block* root_block);
-
- /// Emit an `if` flow node.
- /// @param i the if node to emit
- void EmitIf(core::ir::If* i);
-
- /// Emit an access instruction
- /// @param access the access instruction to emit
- void EmitAccess(core::ir::Access* access);
-
- /// Emit a binary instruction.
- /// @param binary the binary instruction to emit
- void EmitBinary(core::ir::Binary* binary);
-
- /// Emit a bitcast instruction.
- /// @param bitcast the bitcast instruction to emit
- void EmitBitcast(core::ir::Bitcast* bitcast);
-
- /// Emit a builtin function call instruction.
- /// @param call the builtin call instruction to emit
- void EmitSpirvBuiltinCall(spirv::ir::BuiltinCall* call);
-
- /// Emit a builtin function call instruction.
- /// @param call the builtin call instruction to emit
- void EmitCoreBuiltinCall(core::ir::CoreBuiltinCall* call);
-
- /// Emit a construct instruction.
- /// @param construct the construct instruction to emit
- void EmitConstruct(core::ir::Construct* construct);
-
- /// Emit a convert instruction.
- /// @param convert the convert instruction to emit
- void EmitConvert(core::ir::Convert* convert);
-
- /// Emit IO attributes.
- /// @param id the ID of the variable to decorate
- /// @param attrs the shader IO attrs
- /// @param addrspace the address of the variable
- void EmitIOAttributes(uint32_t id,
- const core::ir::IOAttributes& attrs,
- core::AddressSpace addrspace);
-
- /// Emit a load instruction.
- /// @param load the load instruction to emit
- void EmitLoad(core::ir::Load* load);
-
- /// Emit a load vector element instruction.
- /// @param load the load vector element instruction to emit
- void EmitLoadVectorElement(core::ir::LoadVectorElement* load);
-
- /// Emit a loop instruction.
- /// @param loop the loop instruction to emit
- void EmitLoop(core::ir::Loop* loop);
-
- /// Emit a store instruction.
- /// @param store the store instruction to emit
- void EmitStore(core::ir::Store* store);
-
- /// Emit a store vector element instruction.
- /// @param store the store vector element instruction to emit
- void EmitStoreVectorElement(core::ir::StoreVectorElement* store);
-
- /// Emit a switch instruction.
- /// @param swtch the switch instruction to emit
- void EmitSwitch(core::ir::Switch* swtch);
-
- /// Emit a swizzle instruction.
- /// @param swizzle the swizzle instruction to emit
- void EmitSwizzle(core::ir::Swizzle* swizzle);
-
- /// Emit a unary instruction.
- /// @param unary the unary instruction to emit
- void EmitUnary(core::ir::Unary* unary);
-
- /// Emit a user call instruction.
- /// @param call the user call instruction to emit
- void EmitUserCall(core::ir::UserCall* call);
-
- /// Emit a var instruction.
- /// @param var the var instruction to emit
- void EmitVar(core::ir::Var* var);
-
- /// Emit a let instruction.
- /// @param let the let instruction to emit
- void EmitLet(core::ir::Let* let);
-
- /// Emit a terminator instruction.
- /// @param term the terminator instruction to emit
- void EmitTerminator(core::ir::Terminator* term);
-
- /// Emit the OpPhis for the given flow control instruction.
- /// @param inst the flow control instruction
- void EmitExitPhis(core::ir::ControlInstruction* inst);
-
- /// Get the ID of the label of the merge block for a control instruction.
- /// @param ci the control instruction to get the merge label for
- /// @returns the label ID
- uint32_t GetMergeLabel(core::ir::ControlInstruction* ci);
-
- /// Get the ID of the label of the block that will contain a terminator instruction.
- /// @param t the terminator instruction to get the block label for
- /// @returns the label ID
- uint32_t GetTerminatorBlockLabel(core::ir::Terminator* t);
-
- core::ir::Module& ir_;
- core::ir::Builder b_;
- writer::Module module_;
- BinaryWriter writer_;
-
- /// A function type used for an OpTypeFunction declaration.
- struct FunctionType {
- uint32_t return_type_id;
- Vector<uint32_t, 4> param_type_ids;
-
- /// Hasher provides a hash function for the FunctionType.
- struct Hasher {
- /// @param ft the FunctionType to create a hash for
- /// @return the hash value
- inline std::size_t operator()(const FunctionType& ft) const {
- size_t hash = Hash(ft.return_type_id);
- for (auto& p : ft.param_type_ids) {
- hash = HashCombine(hash, p);
- }
- return hash;
- }
- };
-
- /// Equality operator for FunctionType.
- bool operator==(const FunctionType& other) const {
- return (param_type_ids == other.param_type_ids) &&
- (return_type_id == other.return_type_id);
- }
- };
-
- /// The map of types to their result IDs.
- Hashmap<const core::type::Type*, uint32_t, 8> types_;
-
- /// The map of function types to their result IDs.
- Hashmap<FunctionType, uint32_t, 8, FunctionType::Hasher> function_types_;
-
- /// The map of constants to their result IDs.
- Hashmap<const core::constant::Value*, uint32_t, 16> constants_;
-
- /// The map of types to the result IDs of their OpConstantNull instructions.
- Hashmap<const core::type::Type*, uint32_t, 4> constant_nulls_;
-
- /// The map of types to the result IDs of their OpUndef instructions.
- Hashmap<const core::type::Type*, uint32_t, 4> undef_values_;
-
- /// The map of non-constant values to their result IDs.
- Hashmap<core::ir::Value*, uint32_t, 8> values_;
-
- /// The map of blocks to the IDs of their label instructions.
- Hashmap<core::ir::Block*, uint32_t, 8> block_labels_;
-
- /// The map of control instructions to the IDs of the label of their SPIR-V merge blocks.
- Hashmap<core::ir::ControlInstruction*, uint32_t, 8> merge_block_labels_;
-
- /// The map of extended instruction set names to their result IDs.
- Hashmap<std::string_view, uint32_t, 2> imports_;
-
- /// The current function that is being emitted.
- Function current_function_;
-
- /// The merge block for the current if statement
- uint32_t if_merge_label_ = 0;
-
- /// The header block for the current loop statement
- uint32_t loop_header_label_ = 0;
-
- /// The merge block for the current loop statement
- uint32_t loop_merge_label_ = 0;
-
- /// The merge block for the current switch statement
- uint32_t switch_merge_label_ = 0;
-
- bool zero_init_workgroup_memory_ = false;
-};
+/// @returns the generated SPIR-V module on success, or failure
+/// @param module the Tint IR module to generate
+/// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
+/// storage class with OpConstantNull
+tint::Result<Module> PrintModule(core::ir::Module& module, bool zero_init_workgroup_memory);
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/raise/merge_return.cc b/src/tint/lang/spirv/writer/raise/merge_return.cc
index a215fee..dbd06d7 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return.cc
@@ -198,7 +198,7 @@
// Loop over the 'if' instructions, starting with the inner-most, and add any missing
// terminating instructions to the blocks holding the 'if'.
for (auto* i = inner_if; i; i = tint::As<core::ir::If>(i->Block()->Parent())) {
- if (!i->Block()->HasTerminator()) {
+ if (!i->Block()->HasTerminator() && i->Block()->Parent()) {
// Append the exit instruction to the block holding the 'if'.
Vector<core::ir::InstructionResult*, 8> exit_args = i->Results();
if (!i->HasResults()) {
diff --git a/src/tint/lang/spirv/writer/raise/merge_return_test.cc b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
index 2c2ae37..29a5b3f 100644
--- a/src/tint/lang/spirv/writer/raise/merge_return_test.cc
+++ b/src/tint/lang/spirv/writer/raise/merge_return_test.cc
@@ -1180,6 +1180,91 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvWriter_MergeReturnTest, IfElse_Consecutive_ThenUnreachable) {
+ auto* value = b.FunctionParam(ty.i32());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({value});
+
+ b.Append(func->Block(), [&] {
+ {
+ auto* if_ = b.If(b.Equal(ty.bool_(), value, 1_i));
+ b.Append(if_->True(), [&] { b.Return(func, 101_i); });
+ }
+ {
+ auto* ifelse = b.If(b.Equal(ty.bool_(), value, 2_i));
+ b.Append(ifelse->True(), [&] { b.Return(func, 202_i); });
+ b.Append(ifelse->False(), [&] { b.Return(func, 303_i); });
+ }
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ %3:bool = eq %2, 1i
+ if %3 [t: %b2] { # if_1
+ %b2 = block { # true
+ ret 101i
+ }
+ }
+ %4:bool = eq %2, 2i
+ if %4 [t: %b3, f: %b4] { # if_2
+ %b3 = block { # true
+ ret 202i
+ }
+ %b4 = block { # false
+ ret 303i
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%2:i32):i32 -> %b1 {
+ %b1 = block {
+ %return_value:ptr<function, i32, read_write> = var
+ %continue_execution:ptr<function, bool, read_write> = var, true
+ %5:bool = eq %2, 1i
+ if %5 [t: %b2] { # if_1
+ %b2 = block { # true
+ store %continue_execution, false
+ store %return_value, 101i
+ exit_if # if_1
+ }
+ }
+ %6:bool = load %continue_execution
+ if %6 [t: %b3] { # if_2
+ %b3 = block { # true
+ %7:bool = eq %2, 2i
+ if %7 [t: %b4, f: %b5] { # if_3
+ %b4 = block { # true
+ store %continue_execution, false
+ store %return_value, 202i
+ exit_if # if_3
+ }
+ %b5 = block { # false
+ store %continue_execution, false
+ store %return_value, 303i
+ exit_if # if_3
+ }
+ }
+ exit_if # if_2
+ }
+ }
+ %8:i32 = load %return_value
+ ret %8
+ }
+}
+)";
+
+ Run(MergeReturn);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(SpirvWriter_MergeReturnTest, Loop_UnconditionalReturnInBody) {
auto* func = b.Function("foo", ty.i32());
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 5c6a9d7..a5b0bc4 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -43,6 +43,7 @@
#include "src/tint/lang/core/ir/transform/preserve_padding.h"
#include "src/tint/lang/core/ir/transform/robustness.h"
#include "src/tint/lang/core/ir/transform/std140.h"
+#include "src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h"
#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
#include "src/tint/lang/spirv/writer/common/option_builder.h"
#include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
@@ -125,6 +126,7 @@
RUN_TRANSFORM(core::ir::transform::AddEmptyEntryPoint, module);
RUN_TRANSFORM(core::ir::transform::Bgra8UnormPolyfill, module);
RUN_TRANSFORM(core::ir::transform::BlockDecoratedStructs, module);
+ RUN_TRANSFORM(core::ir::transform::VectorizeScalarMatrixConstructors, module);
// CombineAccessInstructions must come after DirectVariableAccess and BlockDecoratedStructs.
// We run this transform as some Qualcomm drivers struggle with partial access chains that
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index 322fcd6..08fd086 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -46,42 +46,53 @@
namespace {
TEST_F(SpirvWriterTest, Type_Void) {
- writer_.Type(ty.void_());
+ auto* fn = b.Function("f", ty.void_());
+ b.Append(fn->Block(), [&] { b.Return(fn); });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%void = OpTypeVoid");
}
TEST_F(SpirvWriterTest, Type_Bool) {
- writer_.Type(ty.bool_());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, bool, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%bool = OpTypeBool");
}
TEST_F(SpirvWriterTest, Type_I32) {
- writer_.Type(ty.i32());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, i32, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%int = OpTypeInt 32 1");
}
TEST_F(SpirvWriterTest, Type_U32) {
- writer_.Type(ty.u32());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, u32, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%uint = OpTypeInt 32 0");
}
TEST_F(SpirvWriterTest, Type_F32) {
- writer_.Type(ty.f32());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, f32, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%float = OpTypeFloat 32");
}
TEST_F(SpirvWriterTest, Type_F16) {
- writer_.Type(ty.f16());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, f16, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpCapability Float16");
@@ -92,56 +103,72 @@
}
TEST_F(SpirvWriterTest, Type_Vec2i) {
- writer_.Type(ty.vec2<i32>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, vec2<i32>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%v2int = OpTypeVector %int 2");
}
TEST_F(SpirvWriterTest, Type_Vec3u) {
- writer_.Type(ty.vec3<u32>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, vec3<u32>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%v3uint = OpTypeVector %uint 3");
}
TEST_F(SpirvWriterTest, Type_Vec4f) {
- writer_.Type(ty.vec4<f32>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, vec4<f32>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%v4float = OpTypeVector %float 4");
}
TEST_F(SpirvWriterTest, Type_Vec2h) {
- writer_.Type(ty.vec2<f16>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, vec2<f16>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%v2half = OpTypeVector %half 2");
}
TEST_F(SpirvWriterTest, Type_Vec4Bool) {
- writer_.Type(ty.vec4<bool>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, vec4<bool>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%v4bool = OpTypeVector %bool 4");
}
TEST_F(SpirvWriterTest, Type_Mat2x3f) {
- writer_.Type(ty.mat2x3(ty.f32()));
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, mat2x3<f32>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%mat2v3float = OpTypeMatrix %v3float 2");
}
TEST_F(SpirvWriterTest, Type_Mat4x2h) {
- writer_.Type(ty.mat4x2(ty.f16()));
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, mat4x2<f16>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%mat4v2half = OpTypeMatrix %v2half 4");
}
TEST_F(SpirvWriterTest, Type_Array_DefaultStride) {
- writer_.Type(ty.array<f32, 4>());
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, array<f32, 4>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpDecorate %_arr_float_uint_4 ArrayStride 4");
@@ -149,7 +176,9 @@
}
TEST_F(SpirvWriterTest, Type_Array_ExplicitStride) {
- writer_.Type(ty.array<f32, 4>(16));
+ b.Append(b.ir.root_block, [&] { //
+ b.Var("v", ty.ptr<private_, read_write>(ty.array<f32, 4>(16)));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpDecorate %_arr_float_uint_4 ArrayStride 16");
@@ -157,7 +186,9 @@
}
TEST_F(SpirvWriterTest, Type_Array_NestedArray) {
- writer_.Type(ty.array(ty.array<f32, 64u>(), 4u));
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, array<array<f32, 64>, 4>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpDecorate %_arr_float_uint_64 ArrayStride 4");
@@ -167,7 +198,10 @@
}
TEST_F(SpirvWriterTest, Type_RuntimeArray_DefaultStride) {
- writer_.Type(ty.array<f32>());
+ b.Append(b.ir.root_block, [&] { //
+ auto* v = b.Var<storage, array<f32>, read_write>("v");
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpDecorate %_runtimearr_float ArrayStride 4");
@@ -175,7 +209,10 @@
}
TEST_F(SpirvWriterTest, Type_RuntimeArray_ExplicitStride) {
- writer_.Type(ty.array<f32>(16));
+ b.Append(b.ir.root_block, [&] { //
+ auto* v = b.Var("v", ty.ptr<storage, read_write>(ty.array<f32>(16)));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpDecorate %_runtimearr_float ArrayStride 16");
@@ -188,7 +225,9 @@
{mod.symbols.Register("a"), ty.f32()},
{mod.symbols.Register("b"), ty.vec4<i32>()},
});
- writer_.Type(str);
+ b.Append(b.ir.root_block, [&] { //
+ b.Var("v", ty.ptr<private_, read_write>(str));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpMemberName %MyStruct 0 \"a\"");
@@ -207,7 +246,9 @@
// Matrices nested inside arrays need layout decorations on the struct member too.
{mod.symbols.Register("arr"), ty.array(ty.array(ty.mat2x4<f16>(), 4), 4)},
});
- writer_.Type(str);
+ b.Append(b.ir.root_block, [&] { //
+ b.Var("v", ty.ptr<private_, read_write>(str));
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("OpMemberDecorate %MyStruct 0 ColMajor");
@@ -218,31 +259,48 @@
}
TEST_F(SpirvWriterTest, Type_Atomic) {
- writer_.Type(ty.atomic(ty.i32()));
+ b.Append(b.ir.root_block, [&] { //
+ b.Var<private_, atomic<i32>, read_write>("v");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%int = OpTypeInt 32 1");
}
TEST_F(SpirvWriterTest, Type_Sampler) {
- writer_.Type(ty.sampler());
+ b.Append(b.ir.root_block, [&] { //
+ auto* v = b.Var("v", ty.ptr<handle, read_write>(ty.sampler()));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpTypeSampler");
+ EXPECT_INST(" = OpTypeSampler");
}
TEST_F(SpirvWriterTest, Type_SamplerComparison) {
- writer_.Type(ty.comparison_sampler());
+ b.Append(b.ir.root_block, [&] { //
+ auto* v = b.Var("v", ty.ptr<handle, read_write>(ty.comparison_sampler()));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpTypeSampler");
+ EXPECT_INST(" = OpTypeSampler");
}
TEST_F(SpirvWriterTest, Type_Samplers_Dedup) {
- auto id = writer_.Type(ty.sampler());
- EXPECT_EQ(writer_.Type(ty.comparison_sampler()), id);
+ b.Append(b.ir.root_block, [&] {
+ auto* v1 = b.Var("v1", ty.ptr<handle, read_write>(ty.sampler()));
+ auto* v2 = b.Var("v2", ty.ptr<handle, read_write>(ty.comparison_sampler()));
+ v1->SetBindingPoint(0, 1);
+ v2->SetBindingPoint(0, 2);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%3 = OpTypeSampler");
+ EXPECT_INST("%_ptr_UniformConstant_3 = OpTypePointer UniformConstant %3");
+ EXPECT_INST("%v1 = OpVariable %_ptr_UniformConstant_3 UniformConstant");
+ EXPECT_INST("%_ptr_UniformConstant_3_0 = OpTypePointer UniformConstant %3");
+ EXPECT_INST("%v2 = OpVariable %_ptr_UniformConstant_3_0 UniformConstant");
}
using Dim = core::type::TextureDimension;
@@ -255,7 +313,11 @@
using Type_SampledTexture = SpirvWriterTestWithParam<TextureCase>;
TEST_P(Type_SampledTexture, Emit) {
auto params = GetParam();
- writer_.Type(ty.Get<core::type::SampledTexture>(params.dim, MakeScalarType(params.format)));
+ b.Append(b.ir.root_block, [&] {
+ auto* v = b.Var("v", ty.ptr<handle, read_write>(ty.Get<core::type::SampledTexture>(
+ params.dim, MakeScalarType(params.format))));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(params.result);
@@ -264,30 +326,33 @@
SpirvWriterTest,
Type_SampledTexture,
testing::Values(
- TextureCase{"%1 = OpTypeImage %float 1D 0 0 0 1 Unknown", Dim::k1d, kF32},
- TextureCase{"%1 = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d, kF32},
- TextureCase{"%1 = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray, kF32},
- TextureCase{"%1 = OpTypeImage %float 3D 0 0 0 1 Unknown", Dim::k3d, kF32},
- TextureCase{"%1 = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube, kF32},
- TextureCase{"%1 = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray, kF32},
- TextureCase{"%1 = OpTypeImage %int 1D 0 0 0 1 Unknown", Dim::k1d, kI32},
- TextureCase{"%1 = OpTypeImage %int 2D 0 0 0 1 Unknown", Dim::k2d, kI32},
- TextureCase{"%1 = OpTypeImage %int 2D 0 1 0 1 Unknown", Dim::k2dArray, kI32},
- TextureCase{"%1 = OpTypeImage %int 3D 0 0 0 1 Unknown", Dim::k3d, kI32},
- TextureCase{"%1 = OpTypeImage %int Cube 0 0 0 1 Unknown", Dim::kCube, kI32},
- TextureCase{"%1 = OpTypeImage %int Cube 0 1 0 1 Unknown", Dim::kCubeArray, kI32},
- TextureCase{"%1 = OpTypeImage %uint 1D 0 0 0 1 Unknown", Dim::k1d, kU32},
- TextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 1 Unknown", Dim::k2d, kU32},
- TextureCase{"%1 = OpTypeImage %uint 2D 0 1 0 1 Unknown", Dim::k2dArray, kU32},
- TextureCase{"%1 = OpTypeImage %uint 3D 0 0 0 1 Unknown", Dim::k3d, kU32},
- TextureCase{"%1 = OpTypeImage %uint Cube 0 0 0 1 Unknown", Dim::kCube, kU32},
- TextureCase{"%1 = OpTypeImage %uint Cube 0 1 0 1 Unknown", Dim::kCubeArray, kU32}));
+ TextureCase{" = OpTypeImage %float 1D 0 0 0 1 Unknown", Dim::k1d, kF32},
+ TextureCase{" = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d, kF32},
+ TextureCase{" = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray, kF32},
+ TextureCase{" = OpTypeImage %float 3D 0 0 0 1 Unknown", Dim::k3d, kF32},
+ TextureCase{" = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube, kF32},
+ TextureCase{" = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray, kF32},
+ TextureCase{" = OpTypeImage %int 1D 0 0 0 1 Unknown", Dim::k1d, kI32},
+ TextureCase{" = OpTypeImage %int 2D 0 0 0 1 Unknown", Dim::k2d, kI32},
+ TextureCase{" = OpTypeImage %int 2D 0 1 0 1 Unknown", Dim::k2dArray, kI32},
+ TextureCase{" = OpTypeImage %int 3D 0 0 0 1 Unknown", Dim::k3d, kI32},
+ TextureCase{" = OpTypeImage %int Cube 0 0 0 1 Unknown", Dim::kCube, kI32},
+ TextureCase{" = OpTypeImage %int Cube 0 1 0 1 Unknown", Dim::kCubeArray, kI32},
+ TextureCase{" = OpTypeImage %uint 1D 0 0 0 1 Unknown", Dim::k1d, kU32},
+ TextureCase{" = OpTypeImage %uint 2D 0 0 0 1 Unknown", Dim::k2d, kU32},
+ TextureCase{" = OpTypeImage %uint 2D 0 1 0 1 Unknown", Dim::k2dArray, kU32},
+ TextureCase{" = OpTypeImage %uint 3D 0 0 0 1 Unknown", Dim::k3d, kU32},
+ TextureCase{" = OpTypeImage %uint Cube 0 0 0 1 Unknown", Dim::kCube, kU32},
+ TextureCase{" = OpTypeImage %uint Cube 0 1 0 1 Unknown", Dim::kCubeArray, kU32}));
using Type_MultisampledTexture = SpirvWriterTestWithParam<TextureCase>;
TEST_P(Type_MultisampledTexture, Emit) {
auto params = GetParam();
- writer_.Type(
- ty.Get<core::type::MultisampledTexture>(params.dim, MakeScalarType(params.format)));
+ b.Append(b.ir.root_block, [&] {
+ auto* v = b.Var("v", ty.ptr<handle, read_write>(ty.Get<core::type::MultisampledTexture>(
+ params.dim, MakeScalarType(params.format))));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(params.result);
@@ -295,14 +360,18 @@
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest,
Type_MultisampledTexture,
- testing::Values(TextureCase{"%1 = OpTypeImage %float 2D 0 0 1 1 Unknown", Dim::k2d, kF32},
- TextureCase{"%1 = OpTypeImage %int 2D 0 0 1 1 Unknown", Dim::k2d, kI32},
- TextureCase{"%1 = OpTypeImage %uint 2D 0 0 1 1 Unknown", Dim::k2d, kU32}));
+ testing::Values(TextureCase{" = OpTypeImage %float 2D 0 0 1 1 Unknown", Dim::k2d, kF32},
+ TextureCase{" = OpTypeImage %int 2D 0 0 1 1 Unknown", Dim::k2d, kI32},
+ TextureCase{" = OpTypeImage %uint 2D 0 0 1 1 Unknown", Dim::k2d, kU32}));
using Type_DepthTexture = SpirvWriterTestWithParam<TextureCase>;
TEST_P(Type_DepthTexture, Emit) {
auto params = GetParam();
- writer_.Type(ty.Get<core::type::DepthTexture>(params.dim));
+ b.Append(b.ir.root_block, [&] { //
+ auto* v =
+ b.Var("v", ty.ptr<handle, read_write>(ty.Get<core::type::DepthTexture>(params.dim)));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(params.result);
@@ -310,39 +379,63 @@
INSTANTIATE_TEST_SUITE_P(
SpirvWriterTest,
Type_DepthTexture,
- testing::Values(TextureCase{"%1 = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d},
- TextureCase{"%1 = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray},
- TextureCase{"%1 = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube},
- TextureCase{"%1 = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray}));
+ testing::Values(TextureCase{" = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d},
+ TextureCase{" = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray},
+ TextureCase{" = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube},
+ TextureCase{" = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray}));
TEST_F(SpirvWriterTest, Type_DepthTexture_DedupWithSampledTexture) {
- writer_.Type(ty.Get<core::type::SampledTexture>(Dim::k2d, ty.f32()));
- writer_.Type(ty.Get<core::type::DepthTexture>(Dim::k2d));
+ b.Append(b.ir.root_block, [&] {
+ auto* v1 = b.Var("v1", ty.ptr<handle, read_write>(
+ ty.Get<core::type::SampledTexture>(Dim::k2d, ty.f32())));
+ auto* v2 =
+ b.Var("v2", ty.ptr<handle, read_write>(ty.Get<core::type::DepthTexture>(Dim::k2d)));
+ v1->SetBindingPoint(0, 1);
+ v2->SetBindingPoint(0, 2);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
-%1 = OpTypeImage %2 2D 0 0 0 1 Unknown
-%4 = OpTypeVoid
-%5 = OpTypeFunction %4
+ EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
+%3 = OpTypeImage %4 2D 0 0 0 1 Unknown
+%2 = OpTypePointer UniformConstant %3
+%1 = OpVariable %2 UniformConstant
+%6 = OpTypePointer UniformConstant %3
+%5 = OpVariable %6 UniformConstant
+%8 = OpTypeVoid
+%9 = OpTypeFunction %8
)");
}
TEST_F(SpirvWriterTest, Type_DepthMultiSampledTexture) {
- writer_.Type(ty.Get<core::type::DepthMultisampledTexture>(Dim::k2d));
+ b.Append(b.ir.root_block, [&] {
+ auto* v = b.Var("v", ty.ptr<handle, read_write>(
+ ty.Get<core::type::DepthMultisampledTexture>(Dim::k2d)));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST("%1 = OpTypeImage %float 2D 0 0 1 1 Unknown");
+ EXPECT_INST(" = OpTypeImage %float 2D 0 0 1 1 Unknown");
}
TEST_F(SpirvWriterTest, Type_DepthMultisampledTexture_DedupWithMultisampledTexture) {
- writer_.Type(ty.Get<core::type::MultisampledTexture>(Dim::k2d, ty.f32()));
- writer_.Type(ty.Get<core::type::DepthMultisampledTexture>(Dim::k2d));
+ b.Append(b.ir.root_block, [&] {
+ auto* v1 = b.Var("v1", ty.ptr<handle, read_write>(
+ ty.Get<core::type::MultisampledTexture>(Dim::k2d, ty.f32())));
+ auto* v2 = b.Var("v2", ty.ptr<handle, read_write>(
+ ty.Get<core::type::DepthMultisampledTexture>(Dim::k2d)));
+ v1->SetBindingPoint(0, 1);
+ v2->SetBindingPoint(0, 2);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
-%1 = OpTypeImage %2 2D 0 0 1 1 Unknown
-%4 = OpTypeVoid
-%5 = OpTypeFunction %4
+ EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
+%3 = OpTypeImage %4 2D 0 0 1 1 Unknown
+%2 = OpTypePointer UniformConstant %3
+%1 = OpVariable %2 UniformConstant
+%6 = OpTypePointer UniformConstant %3
+%5 = OpVariable %6 UniformConstant
+%8 = OpTypeVoid
+%9 = OpTypeFunction %8
)");
}
@@ -355,9 +448,13 @@
using Type_StorageTexture = SpirvWriterTestWithParam<StorageTextureCase>;
TEST_P(Type_StorageTexture, Emit) {
auto params = GetParam();
- writer_.Type(ty.Get<core::type::StorageTexture>(
- params.dim, params.format, core::Access::kWrite,
- core::type::StorageTexture::SubtypeFor(params.format, mod.Types())));
+ b.Append(b.ir.root_block, [&] {
+ auto* v =
+ b.Var("v", ty.ptr<handle, read_write>(ty.Get<core::type::StorageTexture>(
+ params.dim, params.format, core::Access::kWrite,
+ core::type::StorageTexture::SubtypeFor(params.format, mod.Types()))));
+ v->SetBindingPoint(0, 0);
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST(params.result);
@@ -371,71 +468,77 @@
Type_StorageTexture,
testing::Values(
// Test all the dimensions with a single format.
- StorageTextureCase{"%1 = OpTypeImage %float 1D 0 0 0 2 R32f", //
+ StorageTextureCase{" = OpTypeImage %float 1D 0 0 0 2 R32f", //
Dim::k1d, Format::kR32Float},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 R32f", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 R32f", //
Dim::k2d, Format::kR32Float},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 1 0 2 R32f", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 1 0 2 R32f", //
Dim::k2dArray, Format::kR32Float},
- StorageTextureCase{"%1 = OpTypeImage %float 3D 0 0 0 2 R32f", //
+ StorageTextureCase{" = OpTypeImage %float 3D 0 0 0 2 R32f", //
Dim::k3d, Format::kR32Float},
// Test all the formats with 2D.
- StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 R32i", //
+ StorageTextureCase{" = OpTypeImage %int 2D 0 0 0 2 R32i", //
Dim::k2d, Format::kR32Sint},
- StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 R32u", //
+ StorageTextureCase{" = OpTypeImage %uint 2D 0 0 0 2 R32u", //
Dim::k2d, Format::kR32Uint},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rg32f", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rg32f", //
Dim::k2d, Format::kRg32Float},
- StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rg32i", //
+ StorageTextureCase{" = OpTypeImage %int 2D 0 0 0 2 Rg32i", //
Dim::k2d, Format::kRg32Sint},
- StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rg32ui", //
+ StorageTextureCase{" = OpTypeImage %uint 2D 0 0 0 2 Rg32ui", //
Dim::k2d, Format::kRg32Uint},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba16f", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba16f", //
Dim::k2d, Format::kRgba16Float},
- StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba16i", //
+ StorageTextureCase{" = OpTypeImage %int 2D 0 0 0 2 Rgba16i", //
Dim::k2d, Format::kRgba16Sint},
- StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba16ui", //
+ StorageTextureCase{" = OpTypeImage %uint 2D 0 0 0 2 Rgba16ui", //
Dim::k2d, Format::kRgba16Uint},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba32f", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba32f", //
Dim::k2d, Format::kRgba32Float},
- StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba32i", //
+ StorageTextureCase{" = OpTypeImage %int 2D 0 0 0 2 Rgba32i", //
Dim::k2d, Format::kRgba32Sint},
- StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba32ui", //
+ StorageTextureCase{" = OpTypeImage %uint 2D 0 0 0 2 Rgba32ui", //
Dim::k2d, Format::kRgba32Uint},
- StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba8i", //
+ StorageTextureCase{" = OpTypeImage %int 2D 0 0 0 2 Rgba8i", //
Dim::k2d, Format::kRgba8Sint},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba8Snorm", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba8Snorm", //
Dim::k2d, Format::kRgba8Snorm},
- StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba8ui", //
+ StorageTextureCase{" = OpTypeImage %uint 2D 0 0 0 2 Rgba8ui", //
Dim::k2d, Format::kRgba8Uint},
- StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba8", //
+ StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba8", //
Dim::k2d, Format::kRgba8Unorm}));
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpirvWriterTest, Type_Multiple) {
- writer_.Type(ty.i32());
- writer_.Type(ty.u32());
- writer_.Type(ty.f32());
- writer_.Type(ty.f16());
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, i32, read_write>("v1");
+ b.Var<private_, u32, read_write>("v2");
+ b.Var<private_, f32, read_write>("v3");
+ b.Var<private_, f16, read_write>("v4");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
- EXPECT_INST(R"(
- %int = OpTypeInt 32 1
- %uint = OpTypeInt 32 0
- %float = OpTypeFloat 32
- %half = OpTypeFloat 16
-)");
+ EXPECT_INST("%int = OpTypeInt 32 1");
+ EXPECT_INST("%uint = OpTypeInt 32 0");
+ EXPECT_INST("%float = OpTypeFloat 32");
+ EXPECT_INST("%half = OpTypeFloat 16");
}
// Test that we do not emit the same type more than once.
TEST_F(SpirvWriterTest, Type_Deduplicate) {
- auto id = writer_.Type(ty.i32());
- EXPECT_EQ(writer_.Type(ty.i32()), id);
- EXPECT_EQ(writer_.Type(ty.i32()), id);
+ b.Append(b.ir.root_block, [&] {
+ b.Var<private_, i32, read_write>("v1");
+ b.Var<private_, i32, read_write>("v2");
+ b.Var<private_, i32, read_write>("v3");
+ });
ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%_ptr_Private_int = OpTypePointer Private %int");
+ EXPECT_INST("%v1 = OpVariable %_ptr_Private_int Private %4");
+ EXPECT_INST("%v2 = OpVariable %_ptr_Private_int Private %4");
+ EXPECT_INST("%v3 = OpVariable %_ptr_Private_int Private %4");
}
} // namespace
diff --git a/src/tint/lang/spirv/writer/var_test.cc b/src/tint/lang/spirv/writer/var_test.cc
index cf945f3..4cabbce 100644
--- a/src/tint/lang/spirv/writer/var_test.cc
+++ b/src/tint/lang/spirv/writer/var_test.cc
@@ -182,8 +182,7 @@
mod.root_block->Append(b.Var("v", ty.ptr<workgroup, i32>()));
// Create a writer with the zero_init_workgroup_memory flag set to `true`.
- Printer gen(mod, true);
- ASSERT_TRUE(Generate(gen)) << Error() << output_;
+ ASSERT_TRUE(Generate({}, /* zero_init_workgroup_memory */ true)) << Error() << output_;
EXPECT_INST("%4 = OpConstantNull %int");
EXPECT_INST("%v = OpVariable %_ptr_Workgroup_int Workgroup %4");
}
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index a1db49d..1c44959 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -87,8 +87,7 @@
}
// Generate the SPIR-V code.
- auto impl = std::make_unique<Printer>(ir, zero_initialize_workgroup_memory);
- auto spirv = impl->Generate();
+ auto spirv = Print(ir, zero_initialize_workgroup_memory);
if (!spirv) {
return std::move(spirv.Failure());
}
diff --git a/src/tint/lang/spirv/writer/writer_test.cc b/src/tint/lang/spirv/writer/writer_test.cc
index abcdf58..b64c0ef 100644
--- a/src/tint/lang/spirv/writer/writer_test.cc
+++ b/src/tint/lang/spirv/writer/writer_test.cc
@@ -35,12 +35,9 @@
using namespace tint::core::number_suffixes; // NOLINT
TEST_F(SpirvWriterTest, ModuleHeader) {
- auto spirv = writer_.Generate();
- ASSERT_TRUE(spirv) << spirv.Failure();
- auto got = Disassemble(spirv.Get());
- EXPECT_THAT(got, testing::StartsWith(R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-)"));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpCapability Shader");
+ EXPECT_INST("OpMemoryModel Logical GLSL450");
}
TEST_F(SpirvWriterTest, Unreachable) {
diff --git a/src/tint/lang/wgsl/ast/module.cc b/src/tint/lang/wgsl/ast/module.cc
index 55e6142..3ee8cfd 100644
--- a/src/tint/lang/wgsl/ast/module.cc
+++ b/src/tint/lang/wgsl/ast/module.cc
@@ -92,8 +92,8 @@
[&](const ConstAssert* assertion) {
TINT_ASSERT_GENERATION_IDS_EQUAL_IF_VALID(assertion, generation_id);
const_asserts_.Push(assertion);
- },
- [&](Default) { TINT_ICE() << "Unknown global declaration type"; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void Module::AddDiagnosticDirective(const DiagnosticDirective* directive) {
diff --git a/src/tint/lang/wgsl/ast/transform/get_insertion_point.cc b/src/tint/lang/wgsl/ast/transform/get_insertion_point.cc
index 31729ef..4f0d9c7 100644
--- a/src/tint/lang/wgsl/ast/transform/get_insertion_point.cc
+++ b/src/tint/lang/wgsl/ast/transform/get_insertion_point.cc
@@ -58,12 +58,8 @@
// Cannot insert before or after continuing statement of a for-loop
return {};
- },
- [&](Default) -> RetType {
- TINT_ICE() << "expected parent of statement to be "
- "either a block or for loop";
- return {};
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
return {};
diff --git a/src/tint/lang/wgsl/ast/transform/preserve_padding.cc b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
index c61abdd..39ebb2d 100644
--- a/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
+++ b/src/tint/lang/wgsl/ast/transform/preserve_padding.cc
@@ -187,11 +187,8 @@
}
return body;
});
- },
- [&](Default) {
- TINT_ICE() << "unhandled type with padding";
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Checks if a type contains padding bytes.
diff --git a/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc
index 1c7bc25..8f2a7d8 100644
--- a/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/lang/wgsl/ast/transform/promote_side_effects_to_decl.cc
@@ -180,11 +180,8 @@
[&](const PhonyExpression* e) {
no_side_effects.insert(e);
return false;
- },
- [&](Default) {
- TINT_ICE() << "Unhandled expression type";
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
// Adds `e` to `to_hoist` for hoisting to a let later on.
@@ -344,11 +341,8 @@
[&](const PhonyExpression*) {
// Leaf
return false;
- },
- [&](Default) {
- TINT_ICE() << "Unhandled expression type";
- return false;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
// Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
@@ -527,11 +521,8 @@
},
[&](const PhonyExpression* phony) {
return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
- },
- [&](Default) {
- TINT_ICE() << "unhandled expression type: " << expr->TypeInfo().name;
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
// Inserts statements in `stmts` before `stmt`
diff --git a/src/tint/lang/wgsl/ast/transform/robustness.cc b/src/tint/lang/wgsl/ast/transform/robustness.cc
index 84c40ed..c984b29 100644
--- a/src/tint/lang/wgsl/ast/transform/robustness.cc
+++ b/src/tint/lang/wgsl/ast/transform/robustness.cc
@@ -274,12 +274,8 @@
b.Diagnostics().add_error(diag::System::Transform,
core::type::Array::kErrExpectedConstantCount);
return nullptr;
- },
- [&](Default) -> const Expression* {
- TINT_ICE() << "unhandled object type in robustness of array index: "
- << obj_type->UnwrapRef()->FriendlyName();
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Transform the program to insert additional predicate parameters to all user functions that
diff --git a/src/tint/lang/wgsl/ast/transform/single_entry_point.cc b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
index 2fa12be..44e4b86 100644
--- a/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
+++ b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
@@ -128,10 +128,8 @@
}
},
[&](const Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
- [&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled global declaration: " << decl->TypeInfo().name;
- });
+ [&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); }, //
+ TINT_ICE_ON_NO_MATCH);
}
// Clone the entry point.
diff --git a/src/tint/lang/wgsl/ast/transform/std140.cc b/src/tint/lang/wgsl/ast/transform/std140.cc
index f81b507..d933ec6 100644
--- a/src/tint/lang/wgsl/ast/transform/std140.cc
+++ b/src/tint/lang/wgsl/ast/transform/std140.cc
@@ -644,11 +644,8 @@
"_" + ConvertSuffix(mat->type());
},
[&](const core::type::F32*) { return "f32"; }, //
- [&](const core::type::F16*) { return "f16"; },
- [&](Default) {
- TINT_ICE() << "unhandled type for conversion name: " << ty->FriendlyName();
- return "";
- });
+ [&](const core::type::F16*) { return "f16"; }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Generates and returns an expression that loads the value from a std140 uniform buffer,
@@ -748,10 +745,8 @@
b.Assign(i, b.Add(i, 1_a)), //
b.Block(b.Assign(dst_el, src_el))));
stmts.Push(b.Return(var));
- },
- [&](Default) {
- TINT_ICE() << "unhandled type for conversion: " << ty->FriendlyName();
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
// Generate the function
auto ret_ty = CreateASTTypeFor(ctx, ty);
@@ -1094,10 +1089,7 @@
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, vec->type(), name};
}, //
- [&](Default) -> ExprTypeName {
- TINT_ICE() << "unhandled type for access chain: " << ty->FriendlyName();
- return {};
- });
+ TINT_ICE_ON_NO_MATCH);
}
if (auto* swizzle = std::get_if<Swizzle>(&access)) {
/// The access is a vector swizzle.
@@ -1114,10 +1106,7 @@
auto* expr = b.MemberAccessor(lhs, rhs);
return {expr, swizzle_ty, rhs};
}, //
- [&](Default) -> ExprTypeName {
- TINT_ICE() << "unhandled type for access chain: " << ty->FriendlyName();
- return {};
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// The access is a static index.
auto idx = std::get<u32>(access);
@@ -1142,10 +1131,7 @@
auto* expr = b.IndexAccessor(lhs, idx);
return {expr, vec->type(), std::to_string(idx)};
}, //
- [&](Default) -> ExprTypeName {
- TINT_ICE() << "unhandled type for access chain: " << ty->FriendlyName();
- return {};
- });
+ TINT_ICE_ON_NO_MATCH);
}
};
diff --git a/src/tint/lang/wgsl/ast/transform/unshadow.cc b/src/tint/lang/wgsl/ast/transform/unshadow.cc
index e2c4fce..eab34b0 100644
--- a/src/tint/lang/wgsl/ast/transform/unshadow.cc
+++ b/src/tint/lang/wgsl/ast/transform/unshadow.cc
@@ -87,11 +87,8 @@
},
[&](const Parameter*) { //
return b.Param(source, symbol, type, attributes);
- },
- [&](Default) {
- TINT_ICE() << "unexpected variable type: " << decl->TypeInfo().name;
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
};
bool made_changes = false;
diff --git a/src/tint/lang/wgsl/ast/traverse_expressions.h b/src/tint/lang/wgsl/ast/traverse_expressions.h
index 35e0aac..d16fcd2 100644
--- a/src/tint/lang/wgsl/ast/traverse_expressions.h
+++ b/src/tint/lang/wgsl/ast/traverse_expressions.h
@@ -167,14 +167,9 @@
push_single(unary->expr, p.depth + 1);
return true;
},
- [&](Default) {
- if (TINT_LIKELY((expr->IsAnyOf<LiteralExpression, PhonyExpression>()))) {
- return true; // Leaf expression
- }
- TINT_ICE() << "unhandled expression type: "
- << (expr ? expr->TypeInfo().name : "<null>");
- return false;
- });
+ [&](const LiteralExpression*) { return true; },
+ [&](const PhonyExpression*) { return true; }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
}
diff --git a/src/tint/lang/wgsl/helpers/append_vector.cc b/src/tint/lang/wgsl/helpers/append_vector.cc
index 4b8ce32..f0bce9b 100644
--- a/src/tint/lang/wgsl/helpers/append_vector.cc
+++ b/src/tint/lang/wgsl/helpers/append_vector.cc
@@ -104,12 +104,8 @@
[&](const core::type::I32*) { return b->ty.i32(); },
[&](const core::type::U32*) { return b->ty.u32(); },
[&](const core::type::F32*) { return b->ty.f32(); },
- [&](const core::type::Bool*) { return b->ty.bool_(); },
- [&](Default) {
- TINT_UNREACHABLE() << "unsupported vector element type: "
- << packed_el_sem_ty->TypeInfo().name;
- return ast::Type{};
- });
+ [&](const core::type::Bool*) { return b->ty.bool_(); }, //
+ TINT_ICE_ON_NO_MATCH);
auto* statement = vector_sem->Stmt();
diff --git a/src/tint/lang/wgsl/inspector/inspector.cc b/src/tint/lang/wgsl/inspector/inspector.cc
index e07b996..aa9bcfa 100644
--- a/src/tint/lang/wgsl/inspector/inspector.cc
+++ b/src/tint/lang/wgsl/inspector/inspector.cc
@@ -99,11 +99,8 @@
[&](const core::type::F32*) { return ComponentType::kF32; },
[&](const core::type::F16*) { return ComponentType::kF16; },
[&](const core::type::I32*) { return ComponentType::kI32; },
- [&](const core::type::U32*) { return ComponentType::kU32; },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled component type";
- return ComponentType::kUnknown;
- });
+ [&](const core::type::U32*) { return ComponentType::kU32; }, //
+ TINT_ICE_ON_NO_MATCH);
CompositionType compositionType;
if (auto* vec = type->As<core::type::Vector>()) {
@@ -934,11 +931,8 @@
member->Type(), //
[&](const core::type::F32*) { return PixelLocalMemberType::kF32; },
[&](const core::type::I32*) { return PixelLocalMemberType::kI32; },
- [&](const core::type::U32*) { return PixelLocalMemberType::kU32; },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled component type";
- return PixelLocalMemberType::kUnknown;
- });
+ [&](const core::type::U32*) { return PixelLocalMemberType::kU32; }, //
+ TINT_ICE_ON_NO_MATCH);
types.push_back(type);
}
diff --git a/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc b/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
index ca4d82b..cf38225 100644
--- a/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
+++ b/src/tint/lang/wgsl/reader/program_to_ir/program_to_ir.cc
@@ -258,10 +258,8 @@
},
[&](const ast::DiagnosticDirective*) {
// Ignored for now.
- },
- [&](Default) {
- add_error(decl->source, "unknown type: " + std::string(decl->TypeInfo().name));
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
if (diagnostics_.contains_errors()) {
@@ -517,11 +515,8 @@
[&](const ast::IncrementDecrementStatement* i) { EmitIncrementDecrement(i); },
[&](const ast::ConstAssert*) {
// Not emitted
- },
- [&](Default) {
- add_error(stmt->source,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void EmitAssignment(const ast::AssignmentStatement* stmt) {
@@ -982,11 +977,8 @@
impl.current_block_->Append(val);
Bind(expr, val->Result());
return nullptr;
- },
- [&](Default) {
- TINT_ICE() << "invalid accessor: " + std::string(sem->TypeInfo().name);
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!index) {
return;
@@ -1287,11 +1279,8 @@
tasks.Push([=] { Process(e->expr); });
},
[&](const ast::LiteralExpression* e) { EmitLiteral(e); },
- [&](const ast::IdentifierExpression* e) { EmitIdentifier(e); },
- [&](Default) {
- impl.add_error(expr->source,
- "Unhandled: " + std::string(expr->TypeInfo().name));
- });
+ [&](const ast::IdentifierExpression* e) { EmitIdentifier(e); }, //
+ TINT_ICE_ON_NO_MATCH);
}
};
@@ -1378,10 +1367,8 @@
// TODO(dsinclair): Probably want to store the const variable somewhere and then
// in identifier expression log an error if we ever see a const identifier. Add
// this when identifiers and variables are supported.
- },
- [&](Default) {
- add_error(var->source, "unknown variable: " + std::string(var->TypeInfo().name));
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
core::ir::Binary* BinaryOp(const core::type::Type* ty,
diff --git a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
index 6fc0326..1da3470 100644
--- a/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtins_validation_test.cc
@@ -292,6 +292,23 @@
EXPECT_EQ(r()->error(), "12:34 error: store type of @builtin(position) must be 'vec4<f32>'");
}
+TEST_F(ResolverBuiltinsValidationTest, PositionIsVec4h_Fail) {
+ // @vertex
+ // fn main() -> @builtin(position) vec4h { return vec4h(); }
+ Enable(wgsl::Extension::kF16);
+ Func("main", tint::Empty, ty.vec4<f16>(),
+ Vector{
+ Return(Call(ty.vec4<f16>())),
+ },
+ Vector{Stage(ast::PipelineStage::kVertex)},
+ Vector{
+ Builtin(Source{{12, 34}}, core::BuiltinValue::kPosition),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: store type of @builtin(position) must be 'vec4<f32>'");
+}
+
TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) {
// struct MyInputs {
// @builtin(kFragDepth) p: i32;
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph.cc b/src/tint/lang/wgsl/resolver/dependency_graph.cc
index a606a8c..a261629 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph.cc
@@ -136,11 +136,6 @@
/// A map of global name to Global
using GlobalMap = Hashmap<Symbol, Global*, 16>;
-/// Raises an ICE that a global ast::Node type was not handled by this system.
-void UnhandledNode(const ast::Node* node) {
- TINT_ICE() << "unhandled node type: " << node->TypeInfo().name;
-}
-
/// Raises an error diagnostic with the given message and source.
void AddError(diag::List& diagnostics, const std::string& msg, const Source& source) {
diagnostics.add_error(diag::System::Resolver, msg, source);
@@ -206,8 +201,10 @@
[&](const ast::Enable*) {
// Enable directives do not affect the dependency graph.
},
- [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
- [&](Default) { UnhandledNode(global->node); });
+ [&](const ast::ConstAssert* assertion) {
+ TraverseExpression(assertion->condition);
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
private:
@@ -328,12 +325,10 @@
TraverseStatement(w->body);
},
[&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
- [&](Default) {
- if (TINT_UNLIKELY((!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
- ast::DiscardStatement>()))) {
- UnhandledNode(stmt);
- }
- });
+ [&](const ast::BreakStatement*) {}, //
+ [&](const ast::ContinueStatement*) {}, //
+ [&](const ast::DiscardStatement*) {}, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Adds the symbol definition to the current scope, raising an error if two
@@ -384,69 +379,38 @@
/// Traverses the attribute, performing symbol resolution and determining
/// global dependencies.
void TraverseAttribute(const ast::Attribute* attr) {
- bool handled = Switch(
- attr,
- [&](const ast::BindingAttribute* binding) {
- TraverseExpression(binding->expr);
- return true;
- },
- [&](const ast::BuiltinAttribute* builtin) {
- TraverseExpression(builtin->builtin);
- return true;
- },
- [&](const ast::GroupAttribute* group) {
- TraverseExpression(group->expr);
- return true;
- },
- [&](const ast::IdAttribute* id) {
- TraverseExpression(id->expr);
- return true;
- },
- [&](const ast::IndexAttribute* index) {
- TraverseExpression(index->expr);
- return true;
- },
+ Switch(
+ attr, //
+ [&](const ast::BindingAttribute* binding) { TraverseExpression(binding->expr); },
+ [&](const ast::BuiltinAttribute* builtin) { TraverseExpression(builtin->builtin); },
+ [&](const ast::GroupAttribute* group) { TraverseExpression(group->expr); },
+ [&](const ast::IdAttribute* id) { TraverseExpression(id->expr); },
+ [&](const ast::IndexAttribute* index) { TraverseExpression(index->expr); },
[&](const ast::InterpolateAttribute* interpolate) {
TraverseExpression(interpolate->type);
TraverseExpression(interpolate->sampling);
- return true;
},
- [&](const ast::LocationAttribute* loc) {
- TraverseExpression(loc->expr);
- return true;
- },
- [&](const ast::StructMemberAlignAttribute* align) {
- TraverseExpression(align->expr);
- return true;
- },
- [&](const ast::StructMemberSizeAttribute* size) {
- TraverseExpression(size->expr);
- return true;
- },
+ [&](const ast::LocationAttribute* loc) { TraverseExpression(loc->expr); },
+ [&](const ast::StructMemberAlignAttribute* align) { TraverseExpression(align->expr); },
+ [&](const ast::StructMemberSizeAttribute* size) { TraverseExpression(size->expr); },
[&](const ast::WorkgroupAttribute* wg) {
TraverseExpression(wg->x);
TraverseExpression(wg->y);
TraverseExpression(wg->z);
- return true;
},
[&](const ast::InternalAttribute* i) {
for (auto* dep : i->dependencies) {
TraverseExpression(dep);
}
- return true;
+ },
+ [&](Default) {
+ if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::DiagnosticAttribute,
+ ast::InterpolateAttribute, ast::InvariantAttribute,
+ ast::MustUseAttribute, ast::StageAttribute, ast::StrideAttribute,
+ ast::StructMemberOffsetAttribute>()) {
+ TINT_ICE() << "unhandled attribute type: " << attr->TypeInfo().name;
+ }
});
- if (handled) {
- return;
- }
-
- if (attr->IsAnyOf<ast::BuiltinAttribute, ast::DiagnosticAttribute,
- ast::InterpolateAttribute, ast::InvariantAttribute, ast::MustUseAttribute,
- ast::StageAttribute, ast::StrideAttribute,
- ast::StructMemberOffsetAttribute>()) {
- return;
- }
-
- UnhandledNode(attr);
}
/// The type of builtin that a symbol could represent.
@@ -647,11 +611,8 @@
[&](const ast::Variable* var) { return var->name->symbol; },
[&](const ast::DiagnosticDirective*) { return Symbol(); },
[&](const ast::Enable*) { return Symbol(); },
- [&](const ast::ConstAssert*) { return Symbol(); },
- [&](Default) {
- UnhandledNode(node);
- return Symbol{};
- });
+ [&](const ast::ConstAssert*) { return Symbol(); }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// @param node the ast::Node of the global declaration
@@ -672,10 +633,7 @@
[&](const ast::Function*) { return "function"; }, //
[&](const ast::Variable* v) { return v->Kind(); }, //
[&](const ast::ConstAssert*) { return "const_assert"; }, //
- [&](Default) {
- UnhandledNode(node);
- return "<unknown>";
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// Traverses `module`, collecting all the global declarations and populating
@@ -923,11 +881,8 @@
},
[&](const ast::Parameter* n) { //
return "parameter '" + n->name->symbol.Name() + "'";
- },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled ast::Node: " << node->TypeInfo().name;
- return "<unknown>";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
if (auto builtin_fn = BuiltinFn(); builtin_fn != wgsl::BuiltinFn::kNone) {
return "builtin function '" + tint::ToString(builtin_fn) + "'";
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index b63ddcc..8ea7220 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -192,11 +192,8 @@
[&](const ast::TypeDecl* td) { return TypeDecl(td); },
[&](const ast::Function* func) { return Function(func); },
[&](const ast::Variable* var) { return GlobalVariable(var); },
- [&](const ast::ConstAssert* ca) { return ConstAssert(ca); },
- [&](Default) {
- TINT_UNREACHABLE() << "unhandled global declaration: " << decl->TypeInfo().name;
- return false;
- })) {
+ [&](const ast::ConstAssert* ca) { return ConstAssert(ca); }, //
+ TINT_ICE_ON_NO_MATCH)) {
return false;
}
}
@@ -241,14 +238,8 @@
[&](const ast::Var* var) { return Var(var, is_global); },
[&](const ast::Let* let) { return Let(let); },
[&](const ast::Override* override) { return Override(override); },
- [&](const ast::Const* const_) { return Const(const_, is_global); },
- [&](Default) {
- StringStream err;
- err << "Resolver::GlobalVariable() called with a unknown variable type: "
- << v->TypeInfo().name;
- AddICE(err.str(), v->source);
- return nullptr;
- });
+ [&](const ast::Const* const_) { return Const(const_, is_global); }, //
+ TINT_ICE_ON_NO_MATCH);
}
sem::Variable* Resolver::Let(const ast::Let* v) {
@@ -1515,13 +1506,8 @@
current_statement_,
/* constant_value */ nullptr,
/* has_side_effects */ false);
- },
- [&](Default) {
- StringStream err;
- err << "unhandled expression type: " << expr->TypeInfo().name;
- AddICE(err.str(), expr->source);
- return nullptr;
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
if (!sem_expr) {
return nullptr;
}
@@ -3163,11 +3149,8 @@
TINT_UNREACHABLE() << "Unhandled float literal suffix: " << f->suffix;
return nullptr;
},
- [&](const ast::BoolLiteralExpression*) { return b.create<core::type::Bool>(); },
- [&](Default) {
- TINT_UNREACHABLE() << "Unhandled literal type: " << literal->TypeInfo().name;
- return nullptr;
- });
+ [&](const ast::BoolLiteralExpression*) { return b.create<core::type::Bool>(); }, //
+ TINT_ICE_ON_NO_MATCH);
if (ty == nullptr) {
return nullptr;
diff --git a/src/tint/lang/wgsl/resolver/sem_helper.cc b/src/tint/lang/wgsl/resolver/sem_helper.cc
index 19d72b8..adb384e 100644
--- a/src/tint/lang/wgsl/resolver/sem_helper.cc
+++ b/src/tint/lang/wgsl/resolver/sem_helper.cc
@@ -129,12 +129,8 @@
[&](const UnresolvedIdentifier* ui) {
auto name = ui->Identifier()->identifier->symbol.Name();
return "unresolved identifier '" + name + "'";
- },
- [&](Default) -> std::string {
- TINT_ICE() << "unhandled sem::Expression type: "
- << (expr ? expr->TypeInfo().name : "<null>");
- return "<unknown>";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
void SemHelper::ErrorUnexpectedExprKind(
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index f419f8e..08259f7 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -1174,10 +1174,7 @@
return cf; // No impact on uniformity
},
- [&](Default) {
- TINT_ICE() << "unknown statement type: " << std::string(stmt->TypeInfo().name);
- return nullptr;
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// Process an identifier expression.
@@ -1306,11 +1303,7 @@
return std::make_pair(cf, node);
},
- [&](Default) {
- TINT_ICE() << "unknown identifier expression type: "
- << std::string(sem->TypeInfo().name);
- return std::pair<Node*, Node*>(nullptr, nullptr);
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// Process an expression.
@@ -1379,10 +1372,7 @@
return ProcessExpression(cf, u->expr, load_rule);
},
- [&](Default) {
- TINT_ICE() << "unknown expression type: " << std::string(expr->TypeInfo().name);
- return std::pair<Node*, Node*>(nullptr, nullptr);
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// @param u unary expression with op == kIndirection
@@ -1486,11 +1476,7 @@
return ProcessLValueExpression(cf, u->expr, is_partial_reference);
},
- [&](Default) {
- TINT_ICE() << "unknown lvalue expression type: "
- << std::string(expr->TypeInfo().name);
- return LValue{};
- });
+ TINT_ICE_ON_NO_MATCH);
}
/// Process a function call expression.
@@ -1608,8 +1594,8 @@
[&](const sem::ValueConversion*) {
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
function_tag = NoRestriction;
- },
- [&](Default) { TINT_ICE() << "unhandled function call target: " << name; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
cf_after->AddEdge(call_node);
@@ -1899,8 +1885,8 @@
[&](const ast::Expression* e) {
diagnostics_.add_note(diag::System::Resolver,
"result of expression may be non-uniform", e->source);
- },
- [&](Default) { TINT_ICE() << "unhandled source of non-uniformity"; });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// Generate a diagnostic message for a uniformity issue.
diff --git a/src/tint/lang/wgsl/resolver/validation_test.cc b/src/tint/lang/wgsl/resolver/validation_test.cc
index 8468e87..977b79d 100644
--- a/src/tint/lang/wgsl/resolver/validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/validation_test.cc
@@ -141,7 +141,7 @@
b.WrapInFunction(b.create<FakeStmt>());
resolver::Resolve(b);
},
- "internal compiler error: unhandled node type: tint::resolver::FakeStmt");
+ "internal compiler error: Switch() matched no cases. Type: tint::resolver::FakeStmt");
}
TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
@@ -171,7 +171,7 @@
b.WrapInFunction(b.create<FakeExpr>());
Resolver(&b).Resolve();
},
- "internal compiler error: unhandled expression type: tint::resolver::FakeExpr");
+ "internal compiler error: Switch() matched no cases. Type: tint::resolver::FakeExpr");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) {
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 4d34feb..d3c0205 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -607,11 +607,7 @@
}, //
[&](const ast::Let*) { return Let(local); }, //
[&](const ast::Const*) { return true; }, //
- [&](Default) {
- TINT_ICE() << "Validator::Variable() called with a unknown variable type: "
- << decl->TypeInfo().name;
- return false;
- });
+ TINT_ICE_ON_NO_MATCH);
}
bool Validator::GlobalVariable(
@@ -645,12 +641,8 @@
return Var(global);
},
[&](const ast::Override*) { return Override(global, override_ids); },
- [&](const ast::Const*) { return Const(global); },
- [&](Default) {
- TINT_ICE() << "Validator::GlobalVariable() called with a unknown variable type: "
- << decl->TypeInfo().name;
- return false;
- });
+ [&](const ast::Const*) { return Const(global); }, //
+ TINT_ICE_ON_NO_MATCH);
if (!ok) {
return false;
@@ -863,19 +855,21 @@
bool is_output = !is_input;
auto builtin = sem_.Get(attr)->Value();
switch (builtin) {
- case core::BuiltinValue::kPosition:
+ case core::BuiltinValue::kPosition: {
if (stage != ast::PipelineStage::kNone &&
!((is_input && stage == ast::PipelineStage::kFragment) ||
(is_output && stage == ast::PipelineStage::kVertex))) {
is_stage_mismatch = true;
}
- if (!(type->is_float_vector() && type->As<core::type::Vector>()->Width() == 4)) {
+ auto* vec = type->As<core::type::Vector>();
+ if (!(vec && vec->Width() == 4 && vec->type()->Is<core::type::F32>())) {
StringStream err;
err << "store type of @builtin(" << builtin << ") must be 'vec4<f32>'";
AddError(err.str(), attr->source);
return false;
}
break;
+ }
case core::BuiltinValue::kGlobalInvocationId:
case core::BuiltinValue::kLocalInvocationId:
case core::BuiltinValue::kNumWorkgroups:
diff --git a/src/tint/lang/wgsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/wgsl/writer/ast_printer/ast_printer.cc
index a2acd0e..7ff0c0d 100644
--- a/src/tint/lang/wgsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/wgsl/writer/ast_printer/ast_printer.cc
@@ -121,8 +121,8 @@
[&](const ast::TypeDecl* td) { return EmitTypeDecl(td); },
[&](const ast::Function* func) { return EmitFunction(func); },
[&](const ast::Variable* var) { return EmitVariable(Line(), var); },
- [&](const ast::ConstAssert* ca) { return EmitConstAssert(ca); },
- [&](Default) { TINT_UNREACHABLE(); });
+ [&](const ast::ConstAssert* ca) { return EmitConstAssert(ca); }, //
+ TINT_ICE_ON_NO_MATCH);
if (decl != program_.AST().GlobalDeclarations().Back()) {
Line();
}
@@ -157,11 +157,8 @@
EmitExpression(out, alias->type);
out << ";";
},
- [&](const ast::Struct* str) { EmitStructType(str); },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown declared type: " + std::string(ty->TypeInfo().name));
- });
+ [&](const ast::Struct* str) { EmitStructType(str); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitExpression(StringStream& out, const ast::Expression* expr) {
@@ -175,8 +172,8 @@
[&](const ast::LiteralExpression* l) { EmitLiteral(out, l); },
[&](const ast::MemberAccessorExpression* m) { EmitMemberAccessor(out, m); },
[&](const ast::PhonyExpression*) { out << "_"; },
- [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(out, u); },
- [&](Default) { diagnostics_.add_error(diag::System::Writer, "unknown expression type"); });
+ [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(out, u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitIndexAccessor(StringStream& out, const ast::IndexAccessorExpression* expr) {
@@ -253,8 +250,8 @@
<< l->suffix;
}
},
- [&](const ast::IntLiteralExpression* l) { out << l->value << l->suffix; },
- [&](Default) { diagnostics_.add_error(diag::System::Writer, "unknown literal type"); });
+ [&](const ast::IntLiteralExpression* l) { out << l->value << l->suffix; }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitIdentifier(StringStream& out, const ast::IdentifierExpression* expr) {
@@ -434,8 +431,8 @@
}
},
[&](const ast::Let*) { out << "let"; }, [&](const ast::Override*) { out << "override"; },
- [&](const ast::Const*) { out << "const"; },
- [&](Default) { TINT_ICE() << "unhandled variable type " << v->TypeInfo().name; });
+ [&](const ast::Const*) { out << "const"; }, //
+ TINT_ICE_ON_NO_MATCH);
out << " " << v->name->symbol.Name();
@@ -537,10 +534,8 @@
[&](const ast::StrideAttribute* stride) { out << "stride(" << stride->stride << ")"; },
[&](const ast::InternalAttribute* internal) {
out << "internal(" << internal->InternalName() << ")";
- },
- [&](Default) {
- TINT_ICE() << "Unsupported attribute '" << attr->TypeInfo().name << "'";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
}
@@ -679,11 +674,8 @@
[&](const ast::ReturnStatement* r) { EmitReturn(r); },
[&](const ast::ConstAssert* c) { EmitConstAssert(c); },
[&](const ast::SwitchStatement* s) { EmitSwitch(s); },
- [&](const ast::VariableDeclStatement* v) { EmitVariable(Line(), v->variable); },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- });
+ [&](const ast::VariableDeclStatement* v) { EmitVariable(Line(), v->variable); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void ASTPrinter::EmitStatements(VectorRef<const ast::Statement*> stmts) {
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index 5dc9a62..4a87f27 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -89,11 +89,6 @@
#include "src/tint/utils/math/math.h"
#include "src/tint/utils/rtti/switch.h"
-// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
-#define UNHANDLED_CASE(object_ptr) \
- TINT_UNIMPLEMENTED() << "unhandled case in Switch(): " \
- << (object_ptr ? object_ptr->TypeInfo().name : "<null>")
-
// Helper for incrementing nesting_depth_ and then decrementing nesting_depth_ at the end
// of the scope that holds the call.
#define SCOPED_NESTING() \
@@ -202,7 +197,7 @@
tint::Switch(
inst, //
[&](core::ir::Var* var) { Var(var); }, //
- [&](Default) { UNHANDLED_CASE(inst); });
+ TINT_ICE_ON_NO_MATCH);
}
}
const ast::Function* Fn(core::ir::Function* fn) {
@@ -341,7 +336,7 @@
[&](core::ir::Unary* i) { Unary(i); }, //
[&](core::ir::Unreachable*) {}, //
[&](core::ir::Var* i) { Var(i); }, //
- [&](Default) { UNHANDLED_CASE(inst); });
+ TINT_ICE_ON_NO_MATCH);
}
void If(core::ir::If* if_) {
@@ -641,7 +636,7 @@
Bind(c->Result(), b.Bitcast(ty, args[0]), PtrKind::kPtr);
},
[&](core::ir::Discard*) { Append(b.Discard()); }, //
- [&](Default) { UNHANDLED_CASE(call); });
+ TINT_ICE_ON_NO_MATCH);
}
void Load(core::ir::Load* l) { Bind(l->Result(), Expr(l->From())); }
@@ -692,8 +687,8 @@
} else {
TINT_ICE() << "invalid index for struct type: " << index->TypeInfo().name;
}
- },
- [&](Default) { UNHANDLED_CASE(obj_ty); });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
Bind(a->Result(), expr);
}
@@ -868,11 +863,8 @@
[&](const core::type::Array*) { return composite(/* can_splat */ false); },
[&](const core::type::Vector*) { return composite(/* can_splat */ true); },
[&](const core::type::Matrix*) { return composite(/* can_splat */ false); },
- [&](const core::type::Struct*) { return composite(/* can_splat */ false); },
- [&](Default) {
- UNHANDLED_CASE(c->Type());
- return b.Expr("<error>");
- });
+ [&](const core::type::Struct*) { return composite(/* can_splat */ false); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void Enable(wgsl::Extension ext) {
@@ -969,11 +961,8 @@
[&](const core::type::Reference*) {
TINT_ICE() << "reference types should never appear in the IR";
return b.ty.i32();
- },
- [&](Default) {
- UNHANDLED_CASE(ty);
- return b.ty.i32();
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
ast::Type Struct(const core::type::Struct* s) {
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/rename_conflicts.cc b/src/tint/lang/wgsl/writer/ir_to_program/rename_conflicts.cc
index 49d2f00..a7f210f 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/rename_conflicts.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/rename_conflicts.cc
@@ -270,10 +270,8 @@
Switch(
thing, //
[&](core::ir::Value* value) { ir->SetName(value, new_name); },
- [&](core::type::Struct* str) { str->SetName(new_name); },
- [&](Default) {
- TINT_ICE() << "unhandled type for renaming: " << thing->TypeInfo().name;
- });
+ [&](core::type::Struct* str) { str->SetName(new_name); }, //
+ TINT_ICE_ON_NO_MATCH);
}
/// @return true if @p s is a builtin (non-user declared) structure.
diff --git a/src/tint/lang/wgsl/writer/syntax_tree_printer/syntax_tree_printer.cc b/src/tint/lang/wgsl/writer/syntax_tree_printer/syntax_tree_printer.cc
index 90ca7c4..b157a0c 100644
--- a/src/tint/lang/wgsl/writer/syntax_tree_printer/syntax_tree_printer.cc
+++ b/src/tint/lang/wgsl/writer/syntax_tree_printer/syntax_tree_printer.cc
@@ -103,8 +103,8 @@
[&](const ast::TypeDecl* td) { EmitTypeDecl(td); },
[&](const ast::Function* func) { EmitFunction(func); },
[&](const ast::Variable* var) { EmitVariable(var); },
- [&](const ast::ConstAssert* ca) { EmitConstAssert(ca); },
- [&](Default) { TINT_UNREACHABLE(); });
+ [&](const ast::ConstAssert* ca) { EmitConstAssert(ca); }, //
+ TINT_ICE_ON_NO_MATCH);
if (decl != program_.AST().GlobalDeclarations().Back()) {
Line();
@@ -148,11 +148,8 @@
}
Line() << "]";
},
- [&](const ast::Struct* str) { EmitStructType(str); },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown declared type: " + std::string(ty->TypeInfo().name));
- });
+ [&](const ast::Struct* str) { EmitStructType(str); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void SyntaxTreePrinter::EmitExpression(const ast::Expression* expr) {
@@ -166,8 +163,8 @@
[&](const ast::LiteralExpression* l) { EmitLiteral(l); },
[&](const ast::MemberAccessorExpression* m) { EmitMemberAccessor(m); },
[&](const ast::PhonyExpression*) { Line() << "[PhonyExpression]"; },
- [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(u); },
- [&](Default) { diagnostics_.add_error(diag::System::Writer, "unknown expression type"); });
+ [&](const ast::UnaryOpExpression* u) { EmitUnaryOp(u); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void SyntaxTreePrinter::EmitIndexAccessor(const ast::IndexAccessorExpression* expr) {
@@ -272,8 +269,8 @@
<< l->suffix;
}
},
- [&](const ast::IntLiteralExpression* l) { Line() << l->value << l->suffix; },
- [&](Default) { diagnostics_.add_error(diag::System::Writer, "unknown literal type"); });
+ [&](const ast::IntLiteralExpression* l) { Line() << l->value << l->suffix; }, //
+ TINT_ICE_ON_NO_MATCH);
}
Line() << "]";
}
@@ -499,8 +496,8 @@
},
[&](const ast::Let*) { Line() << "Let []"; },
[&](const ast::Override*) { Line() << "Override []"; },
- [&](const ast::Const*) { Line() << "Const []"; },
- [&](Default) { TINT_ICE() << "unhandled variable type " << v->TypeInfo().name; });
+ [&](const ast::Const*) { Line() << "Const []"; }, //
+ TINT_ICE_ON_NO_MATCH);
Line() << "name: " << v->name->symbol.Name();
@@ -648,10 +645,8 @@
},
[&](const ast::InternalAttribute* internal) {
Line() << "InternalAttribute [" << internal->InternalName() << "]";
- },
- [&](Default) {
- TINT_ICE() << "Unsupported attribute '" << attr->TypeInfo().name << "'";
- });
+ }, //
+ TINT_ICE_ON_NO_MATCH);
}
}
@@ -812,11 +807,8 @@
[&](const ast::ReturnStatement* r) { EmitReturn(r); },
[&](const ast::ConstAssert* c) { EmitConstAssert(c); },
[&](const ast::SwitchStatement* s) { EmitSwitch(s); },
- [&](const ast::VariableDeclStatement* v) { EmitVariable(v->variable); },
- [&](Default) {
- diagnostics_.add_error(diag::System::Writer,
- "unknown statement type: " + std::string(stmt->TypeInfo().name));
- });
+ [&](const ast::VariableDeclStatement* v) { EmitVariable(v->variable); }, //
+ TINT_ICE_ON_NO_MATCH);
}
void SyntaxTreePrinter::EmitStatements(VectorRef<const ast::Statement*> stmts) {
diff --git a/src/tint/utils/macros/compiler.h b/src/tint/utils/macros/compiler.h
index f5f7140..b1fc791 100644
--- a/src/tint/utils/macros/compiler.h
+++ b/src/tint/utils/macros/compiler.h
@@ -46,6 +46,7 @@
#define TINT_DISABLE_WARNING_FLOAT_EQUAL /* currently no-op */
#define TINT_DISABLE_WARNING_DEPRECATED __pragma(warning(disable : 4996))
#define TINT_DISABLE_WARNING_RESERVED_IDENTIFIER /* currently no-op */
+#define TINT_DISABLE_WARNING_UNUSED_VALUE /* currently no-op */
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
@@ -75,6 +76,7 @@
#define TINT_DISABLE_WARNING_DEPRECATED /* currently no-op */
#define TINT_DISABLE_WARNING_RESERVED_IDENTIFIER \
_Pragma("clang diagnostic ignored \"-Wreserved-identifier\"")
+#define TINT_DISABLE_WARNING_UNUSED_VALUE _Pragma("clang diagnostic ignored \"-Wunused-value\"")
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
@@ -103,6 +105,7 @@
#define TINT_DISABLE_WARNING_FLOAT_EQUAL /* currently no-op */
#define TINT_DISABLE_WARNING_DEPRECATED /* currently no-op */
#define TINT_DISABLE_WARNING_RESERVED_IDENTIFIER /* currently no-op */
+#define TINT_DISABLE_WARNING_UNUSED_VALUE _Pragma("GCC diagnostic ignored \"-Wunused-value\"")
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
diff --git a/src/tint/utils/result/result.h b/src/tint/utils/result/result.h
index 2e8fb8e..d040294 100644
--- a/src/tint/utils/result/result.h
+++ b/src/tint/utils/result/result.h
@@ -138,6 +138,13 @@
}
/// @returns the success value
+ /// @warning attempting to call this when the Result holds an failure will result in UB.
+ SUCCESS_TYPE* operator->() {
+ Validate();
+ return &(Get());
+ }
+
+ /// @returns the success value
/// @warning attempting to call this when the Result holds an failure value will result in UB.
const SUCCESS_TYPE& Get() const {
Validate();
diff --git a/src/tint/utils/rtti/BUILD.bazel b/src/tint/utils/rtti/BUILD.bazel
index 9b82b61..e40a7b2 100644
--- a/src/tint/utils/rtti/BUILD.bazel
+++ b/src/tint/utils/rtti/BUILD.bazel
@@ -43,9 +43,11 @@
],
hdrs = [
"castable.h",
+ "ignore.h",
"switch.h",
],
deps = [
+ "//src/tint/utils/ice",
"//src/tint/utils/macros",
"//src/tint/utils/math",
"//src/tint/utils/memory",
@@ -62,6 +64,7 @@
"switch_test.cc",
],
deps = [
+ "//src/tint/utils/ice",
"//src/tint/utils/macros",
"//src/tint/utils/math",
"//src/tint/utils/memory",
@@ -79,6 +82,7 @@
"switch_bench.cc",
],
deps = [
+ "//src/tint/utils/ice",
"//src/tint/utils/macros",
"//src/tint/utils/math",
"//src/tint/utils/memory",
diff --git a/src/tint/utils/rtti/BUILD.cmake b/src/tint/utils/rtti/BUILD.cmake
index 53d5e8f..6c58e07 100644
--- a/src/tint/utils/rtti/BUILD.cmake
+++ b/src/tint/utils/rtti/BUILD.cmake
@@ -41,10 +41,12 @@
tint_add_target(tint_utils_rtti lib
utils/rtti/castable.cc
utils/rtti/castable.h
+ utils/rtti/ignore.h
utils/rtti/switch.h
)
tint_target_add_dependencies(tint_utils_rtti lib
+ tint_utils_ice
tint_utils_macros
tint_utils_math
tint_utils_memory
@@ -61,6 +63,7 @@
)
tint_target_add_dependencies(tint_utils_rtti_test test
+ tint_utils_ice
tint_utils_macros
tint_utils_math
tint_utils_memory
@@ -81,6 +84,7 @@
)
tint_target_add_dependencies(tint_utils_rtti_bench bench
+ tint_utils_ice
tint_utils_macros
tint_utils_math
tint_utils_memory
diff --git a/src/tint/utils/rtti/BUILD.gn b/src/tint/utils/rtti/BUILD.gn
index cbc8935..8df0474 100644
--- a/src/tint/utils/rtti/BUILD.gn
+++ b/src/tint/utils/rtti/BUILD.gn
@@ -46,9 +46,11 @@
sources = [
"castable.cc",
"castable.h",
+ "ignore.h",
"switch.h",
]
deps = [
+ "${tint_src_dir}/utils/ice",
"${tint_src_dir}/utils/macros",
"${tint_src_dir}/utils/math",
"${tint_src_dir}/utils/memory",
@@ -63,6 +65,7 @@
]
deps = [
"${tint_src_dir}:gmock_and_gtest",
+ "${tint_src_dir}/utils/ice",
"${tint_src_dir}/utils/macros",
"${tint_src_dir}/utils/math",
"${tint_src_dir}/utils/memory",
@@ -76,6 +79,7 @@
sources = [ "switch_bench.cc" ]
deps = [
"${tint_src_dir}:google_benchmark",
+ "${tint_src_dir}/utils/ice",
"${tint_src_dir}/utils/macros",
"${tint_src_dir}/utils/math",
"${tint_src_dir}/utils/memory",
diff --git a/src/tint/utils/rtti/castable.h b/src/tint/utils/rtti/castable.h
index d946283..e877d23 100644
--- a/src/tint/utils/rtti/castable.h
+++ b/src/tint/utils/rtti/castable.h
@@ -35,6 +35,7 @@
#include <utility>
#include "src/tint/utils/math/crc32.h"
+#include "src/tint/utils/rtti/ignore.h"
#include "src/tint/utils/traits/traits.h"
#if defined(__clang__)
@@ -58,10 +59,6 @@
// Forward declarations
namespace tint {
class CastableBase;
-
-/// Ignore is used as a special type used for skipping over types for trait
-/// helper functions.
-class Ignore {};
} // namespace tint
namespace tint::detail {
diff --git a/src/tint/utils/rtti/ignore.h b/src/tint/utils/rtti/ignore.h
new file mode 100644
index 0000000..f4e4ddb
--- /dev/null
+++ b/src/tint/utils/rtti/ignore.h
@@ -0,0 +1,59 @@
+
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_UTILS_RTTI_IGNORE_H_
+#define SRC_TINT_UTILS_RTTI_IGNORE_H_
+
+namespace tint {
+
+/// Ignore is used as a special type used for skipping over types for trait helper functions.
+class Ignore {};
+
+} // namespace tint
+
+namespace std {
+
+/// A specialization of std::common_type where the first template argument is tint::Ignore.
+/// Used so that std::common_type will ignore template arguments of type tint::Ignore.
+template <typename T>
+struct common_type<tint::Ignore, T> {
+ /// The second template type.
+ using type = T;
+};
+
+/// A specialization of std::common_type where the second template argument is tint::Ignore.
+/// Used so that std::common_type will ignore template arguments of type tint::Ignore.
+template <typename T>
+struct common_type<T, tint::Ignore> {
+ /// The first template type.
+ using type = T;
+};
+
+} // namespace std
+
+#endif // SRC_TINT_UTILS_RTTI_IGNORE_H_
diff --git a/src/tint/utils/rtti/switch.h b/src/tint/utils/rtti/switch.h
index fc57121..24173d0 100644
--- a/src/tint/utils/rtti/switch.h
+++ b/src/tint/utils/rtti/switch.h
@@ -31,9 +31,11 @@
#include <tuple>
#include <utility>
+#include "src/tint/utils/ice/ice.h"
#include "src/tint/utils/macros/defer.h"
#include "src/tint/utils/memory/bitcast.h"
#include "src/tint/utils/rtti/castable.h"
+#include "src/tint/utils/rtti/ignore.h"
namespace tint {
@@ -48,6 +50,31 @@
/// ```
struct Default {};
+/// SwitchMustMatchCase is a flag that can be passed as the last argument to Switch() which will
+/// trigger an ICE if none of the cases matched. Cannot be used with Default.
+/// See TINT_ICE_ON_NO_MATCH
+struct SwitchMustMatchCase {
+ /// The source file that holds the TINT_ICE_ON_NO_MATCH
+ const char* file = "<unknown>";
+ /// The source line that holds the TINT_ICE_ON_NO_MATCH
+ unsigned int line = 0;
+};
+
+/// SwitchMustMatchCase is a flag that can be passed as the last argument to Switch() which will
+/// trigger an ICE if none of the cases matched. Cannot be used with Default.
+///
+/// Example:
+/// ```
+/// Switch(object,
+/// [&](TypeA*) { /* ... */ },
+/// [&](TypeB*) { /* ... */ },
+/// TINT_ICE_ON_NO_MATCH);
+/// ```
+#define TINT_ICE_ON_NO_MATCH \
+ tint::SwitchMustMatchCase { \
+ __FILE__, __LINE__ \
+ }
+
} // namespace tint
namespace tint::detail {
@@ -59,25 +86,38 @@
using SwitchCaseType =
std::remove_pointer_t<tint::traits::ParameterType<std::remove_reference_t<FN>, 0>>;
-/// Evaluates to true if the function `FN` has the signature of a Default case in a Switch().
-/// @see Switch().
-template <typename FN>
-inline constexpr bool IsDefaultCase =
- std::is_same_v<tint::traits::ParameterType<std::remove_reference_t<FN>, 0>, Default>;
-
/// Searches the list of Switch cases for a Default case, returning the index of the Default case.
/// If the a Default case is not found in the tuple, then -1 is returned.
template <typename TUPLE, std::size_t START_IDX = 0>
constexpr int IndexOfDefaultCase() {
if constexpr (START_IDX < std::tuple_size_v<TUPLE>) {
- return IsDefaultCase<std::tuple_element_t<START_IDX, TUPLE>>
- ? static_cast<int>(START_IDX)
- : IndexOfDefaultCase<TUPLE, START_IDX + 1>();
+ using T = std::decay_t<std::tuple_element_t<START_IDX, TUPLE>>;
+ if constexpr (std::is_same_v<T, SwitchMustMatchCase>) {
+ return -1;
+ } else if constexpr (std::is_same_v<tint::traits::ParameterType<T, 0>, Default>) {
+ return static_cast<int>(START_IDX);
+ } else {
+ return IndexOfDefaultCase<TUPLE, START_IDX + 1>();
+ }
} else {
return -1;
}
}
+/// Searches the list of Switch cases for a SwitchMustMatchCase flag, returning the index of the
+/// SwitchMustMatchCase case. If the a SwitchMustMatchCase case is not found in the tuple, then -1
+/// is returned.
+template <typename TUPLE, std::size_t START_IDX = 0>
+constexpr int IndexOfSwitchMustMatchCase() {
+ if constexpr (START_IDX < std::tuple_size_v<TUPLE>) {
+ using T = std::decay_t<std::tuple_element_t<START_IDX, TUPLE>>;
+ return std::is_same_v<T, SwitchMustMatchCase>
+ ? static_cast<int>(START_IDX)
+ : IndexOfSwitchMustMatchCase<TUPLE, START_IDX + 1>();
+ } else {
+ return -1;
+ }
+}
/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore.
template <typename T>
using NullptrToIgnore = std::conditional_t<std::is_same_v<T, std::nullptr_t>, tint::Ignore, T>;
@@ -140,6 +180,31 @@
REQUESTED_TYPE,
CASE_RETURN_TYPES...>::type;
+/// SwitchCaseReturnTypeImpl is the implementation of SwitchCaseReturnType
+template <typename CASE, bool is_flag>
+struct SwitchCaseReturnTypeImpl;
+
+/// SwitchCaseReturnTypeImpl specialization for non-flags.
+template <typename CASE>
+struct SwitchCaseReturnTypeImpl<CASE, /* is_flag */ false> {
+ /// The case function's return type.
+ using type = tint::traits::ReturnType<CASE>;
+};
+
+/// SwitchCaseReturnTypeImpl specialization for flags.
+template <typename CASE>
+struct SwitchCaseReturnTypeImpl<CASE, /* is_flag */ true> {
+ /// These are not functions, they have no return type.
+ using type = tint::Ignore;
+};
+
+/// Resolves to the return type for a Switch() case.
+/// If CASE is a flag like SwitchMustMatchCase, then resolves to tint::Ignore
+template <typename CASE>
+using SwitchCaseReturnType = typename SwitchCaseReturnTypeImpl<
+ CASE,
+ std::is_same_v<std::decay_t<CASE>, SwitchMustMatchCase>>::type;
+
} // namespace tint::detail
namespace tint {
@@ -156,6 +221,9 @@
/// An optional default case function with the signature `R(Default)` can be used as the last case.
/// This default case will be called if all previous cases failed to match.
///
+/// The last argument may be SwitchMustMatchCase, in which case the Switch will trigger an ICE if
+/// none of the cases matched. SwitchMustMatchCase cannot be used with a default case.
+///
/// If `object` is nullptr and a default case is provided, then the default case will be called. If
/// `object` is nullptr and no default case is provided, then no cases will be called.
///
@@ -169,35 +237,58 @@
/// [&](TypeA*) { /* ... */ },
/// [&](TypeB*) { /* ... */ },
/// [&](Default) { /* Called if object is not TypeA or TypeB */ });
+///
+/// Switch(object,
+/// [&](TypeA*) { /* ... */ },
+/// [&](TypeB*) { /* ... */ },
+/// SwitchMustMatchCase); /* ICE if object is not TypeA or TypeB */
/// ```
///
/// @param object the object who's type is used to
-/// @param cases the switch cases
+/// @param args the switch cases followed by an optional TINT_ICE_ON_NO_MATCH
/// @return the value returned by the called case. If no cases matched, then the zero value for the
/// consistent case type.
-template <typename RETURN_TYPE = tint::detail::Infer, typename T = CastableBase, typename... CASES>
-inline auto Switch(T* object, CASES&&... cases) {
- using ReturnType =
- tint::detail::SwitchReturnType<RETURN_TYPE, tint::traits::ReturnType<CASES>...>;
- static constexpr int kDefaultIndex = tint::detail::IndexOfDefaultCase<std::tuple<CASES...>>();
+template <typename RETURN_TYPE = tint::detail::Infer, typename T = CastableBase, typename... ARGS>
+inline auto Switch(T* object, ARGS&&... args) {
+ TINT_BEGIN_DISABLE_WARNING(UNUSED_VALUE);
+
+ using ArgsTuple = std::tuple<ARGS...>;
+ static constexpr int kMustMatchCaseIndex =
+ tint::detail::IndexOfSwitchMustMatchCase<ArgsTuple>();
+ static constexpr bool kHasMustMatchCase = kMustMatchCaseIndex >= 0;
+ static constexpr int kDefaultIndex = tint::detail::IndexOfDefaultCase<ArgsTuple>();
static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
+ using ReturnType =
+ tint::detail::SwitchReturnType<RETURN_TYPE, tint::detail::SwitchCaseReturnType<ARGS>...>;
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
// Static assertions
static constexpr bool kDefaultIsOK =
- kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(sizeof...(CASES) - 1);
+ kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(sizeof...(ARGS) - 1);
+ static constexpr bool kMustMatchCaseIsOK =
+ kMustMatchCaseIndex == -1 || kMustMatchCaseIndex == static_cast<int>(sizeof...(ARGS) - 1);
static constexpr bool kReturnIsOK =
kHasDefaultCase || !kHasReturnType || std::is_constructible_v<ReturnType>;
static_assert(kDefaultIsOK, "Default case must be last in Switch()");
+ static_assert(kMustMatchCaseIsOK, "SwitchMustMatchCase must be last argument in Switch()");
+ static_assert(!kHasDefaultCase || !kHasMustMatchCase,
+ "SwitchMustMatchCase cannot be used with a Default case");
static_assert(kReturnIsOK,
"Switch() requires either a Default case or a return type that is either void or "
"default-constructable");
if (!object) { // Object is nullptr, so no cases can match
- if constexpr (kHasDefaultCase) {
+ if constexpr (kHasMustMatchCase) {
+ const SwitchMustMatchCase& info = (args, ...);
+ tint::InternalCompilerError(info.file, info.line) << "Switch() passed nullptr";
+ if constexpr (kHasReturnType) {
+ return ReturnType{};
+ } else {
+ return;
+ }
+ } else if constexpr (kHasDefaultCase) {
// Evaluate default case.
- auto&& default_case =
- std::get<kDefaultIndex>(std::forward_as_tuple(std::forward<CASES>(cases)...));
+ const auto& default_case = (args, ...);
return static_cast<ReturnType>(default_case(Default{}));
} else {
// No default case, no case can match.
@@ -229,24 +320,29 @@
// `result` pointer.
auto try_case = [&](auto&& case_fn) {
using CaseFunc = std::decay_t<decltype(case_fn)>;
- using CaseType = tint::detail::SwitchCaseType<CaseFunc>;
bool success = false;
- if constexpr (std::is_same_v<CaseType, Default>) {
- if constexpr (kHasReturnType) {
- new (result) ReturnType(static_cast<ReturnType>(case_fn(Default{})));
- } else {
- case_fn(Default{});
- }
- success = true;
+ if constexpr (std::is_same_v<CaseFunc, SwitchMustMatchCase>) {
+ tint::InternalCompilerError(case_fn.file, case_fn.line)
+ << "Switch() matched no cases. Type: " << type_info.name;
} else {
- if (type_info.Is<CaseType>()) {
- auto* v = static_cast<CaseType*>(object);
+ using CaseType = tint::detail::SwitchCaseType<CaseFunc>;
+ if constexpr (std::is_same_v<CaseType, Default>) {
if constexpr (kHasReturnType) {
- new (result) ReturnType(static_cast<ReturnType>(case_fn(v)));
+ new (result) ReturnType(static_cast<ReturnType>(case_fn(Default{})));
} else {
- case_fn(v);
+ case_fn(Default{});
}
success = true;
+ } else {
+ if (type_info.Is<CaseType>()) {
+ auto* v = static_cast<CaseType*>(object);
+ if constexpr (kHasReturnType) {
+ new (result) ReturnType(static_cast<ReturnType>(case_fn(v)));
+ } else {
+ case_fn(v);
+ }
+ success = true;
+ }
}
}
return success;
@@ -254,7 +350,7 @@
// Use a logical-or fold expression to try each of the cases in turn, until one matches the
// object type or a Default is reached. `handled` is true if a case function was called.
- bool handled = ((try_case(std::forward<CASES>(cases)) || ...));
+ bool handled = ((try_case(std::forward<ARGS>(args)) || ...));
if constexpr (kHasReturnType) {
if constexpr (kHasDefaultCase) {
@@ -270,6 +366,8 @@
return ReturnType{};
}
}
+
+ TINT_END_DISABLE_WARNING(UNUSED_VALUE);
}
} // namespace tint
diff --git a/src/tint/utils/rtti/switch_test.cc b/src/tint/utils/rtti/switch_test.cc
index ed13732..f915c9e 100644
--- a/src/tint/utils/rtti/switch_test.cc
+++ b/src/tint/utils/rtti/switch_test.cc
@@ -30,6 +30,7 @@
#include <memory>
#include <string>
+#include "gtest/gtest-spi.h"
#include "gtest/gtest.h"
namespace tint {
@@ -165,6 +166,121 @@
}
}
+TEST(Castable, SwitchMustMatch_MatchedWithoutReturnValue) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+ std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+ {
+ bool ok = false;
+ Switch(
+ frog.get(), //
+ [&](Amphibian*) { ok = true; }, //
+ [&](Mammal*) {}, //
+ TINT_ICE_ON_NO_MATCH);
+ EXPECT_TRUE(ok);
+ }
+ {
+ bool ok = false;
+ Switch(
+ bear.get(), //
+ [&](Amphibian*) {}, //
+ [&](Mammal*) { ok = true; }, //
+ TINT_ICE_ON_NO_MATCH); //
+ EXPECT_TRUE(ok);
+ }
+ {
+ bool ok = false;
+ Switch(
+ gecko.get(), //
+ [&](Reptile*) { ok = true; }, //
+ [&](Amphibian*) {}, //
+ TINT_ICE_ON_NO_MATCH); //
+ EXPECT_TRUE(ok);
+ }
+}
+
+TEST(Castable, SwitchMustMatch_MatchedWithReturnValue) {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+ std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+ {
+ int res = Switch(
+ frog.get(), //
+ [&](Amphibian*) { return 1; }, //
+ [&](Mammal*) { return 0; }, //
+ TINT_ICE_ON_NO_MATCH);
+ EXPECT_EQ(res, 1);
+ }
+ {
+ int res = Switch(
+ bear.get(), //
+ [&](Amphibian*) { return 0; }, //
+ [&](Mammal*) { return 2; }, //
+ TINT_ICE_ON_NO_MATCH);
+ EXPECT_EQ(res, 2);
+ }
+ {
+ int res = Switch(
+ gecko.get(), //
+ [&](Reptile*) { return 3; }, //
+ [&](Amphibian*) { return 0; }, //
+ TINT_ICE_ON_NO_MATCH);
+ EXPECT_EQ(res, 3);
+ }
+}
+
+TEST(Castable, SwitchMustMatch_NoMatchWithoutReturnValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ Switch(
+ frog.get(), //
+ [&](Reptile*) {}, //
+ [&](Mammal*) {}, //
+ TINT_ICE_ON_NO_MATCH);
+ },
+ "internal compiler error: Switch() matched no cases. Type: Frog");
+}
+
+TEST(Castable, SwitchMustMatch_NoMatchWithReturnValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+ int res = Switch(
+ frog.get(), //
+ [&](Reptile*) { return 1; }, //
+ [&](Mammal*) { return 2; }, //
+ TINT_ICE_ON_NO_MATCH);
+ ASSERT_EQ(res, 0);
+ },
+ "internal compiler error: Switch() matched no cases. Type: Frog");
+}
+
+TEST(Castable, SwitchMustMatch_NullptrWithoutReturnValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Switch(
+ static_cast<CastableBase*>(nullptr), //
+ [&](Reptile*) {}, //
+ [&](Mammal*) {}, //
+ TINT_ICE_ON_NO_MATCH);
+ },
+ "internal compiler error: Switch() passed nullptr");
+}
+
+TEST(Castable, SwitchMustMatch_NullptrWithReturnValue) {
+ EXPECT_FATAL_FAILURE(
+ {
+ int res = Switch(
+ static_cast<CastableBase*>(nullptr), //
+ [&](Reptile*) { return 1; }, //
+ [&](Mammal*) { return 2; }, //
+ TINT_ICE_ON_NO_MATCH);
+ ASSERT_EQ(res, 0);
+ },
+ "internal compiler error: Switch() passed nullptr");
+}
+
TEST(Castable, SwitchMatchFirst) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
{