Import Tint changes from Dawn
Changes:
- 7171205a3fc23f5d7a8f971bf551ebc3adb2498c [ir][spirv-writer] Implement any builtin by James Price <jrprice@google.com>
- 6a9d6a021dee9469dcbae0f819ee2bdfedd5b06d [ir][spirv-writer] Implement derivative builtins by James Price <jrprice@google.com>
- fb1f96662716af04e17578d5d8fd591bfd61f5e9 [ir][spirv-writer] Implement binary modulo by James Price <jrprice@google.com>
- 08c60a6417096d86a6fa05e3d0a5ee02056ab72a [ir][spirv-writer] Emit bitcast instructions by James Price <jrprice@google.com>
- 5bc6329cca4eb7979a6105fc0f906b5bb88596db [ir][spirv-writer] Implement shift operations by James Price <jrprice@google.com>
- a5e7e951a6168082f7f15b39b99e91db0bb6b26b [tint][ir][tint_ir_roundtrip_fuzzer] Emit WGSL output on ... by Ben Clayton <bclayton@google.com>
- 2c298188fba80df81201b555e7c576c4332f0ba8 [tint][it][ToProgram] Correctly handle pointers by Ben Clayton <bclayton@google.com>
- 9cea3ca1b5184a08d2e1da28c90d3c0577671e83 [ir][spirv-writer] Implement cross builtin by James Price <jrprice@google.com>
- ab33ce93944a14ff4d7c75ffb342facc45c74c16 [ir][spirv-writer] Implement unary instructions by James Price <jrprice@google.com>
- 737e156a08ecc9f741c6ac573600da393304eb83 [ir][spirv-writer] Implement trig builtins by James Price <jrprice@google.com>
- aeae6a3fd8194b6fb921b44630885cd3b1070bb3 [ir][spirv-writer] Implement length builtin by James Price <jrprice@google.com>
- 37111a36b64d72d493ab0e66a382f6e9f8a81d07 [ir][spirv-writer] Implement normalize builtin by James Price <jrprice@google.com>
- 8558e4873b86783d89404300487f3e41bc854856 [ir][spirv-writer] Implement convert instructions by James Price <jrprice@google.com>
- e04cdec79537b9f577065e00c0e44e4fedf2fa0f [tint][ir] Fix new validation failures by Ben Clayton <bclayton@google.com>
- 27577970ea28d378a6a91a8f4dab5bfd44750205 [ir] Hookup IR to test runner by dan sinclair <dsinclair@chromium.org>
- e94b9bc4864c36a489df832154aac4a829276205 [ir][spirv-writer] Implement distance builtin by James Price <jrprice@google.com>
- fb89ee855e16bbce7746ab5cead2594f8e65f6de [ir][spirv-writer] Implement clamp builtin by James Price <jrprice@google.com>
- 74c9a5acd7d216f513679358c8b721b6c1bdb1e0 [ir][spirv-writer] Implement multiply by James Price <jrprice@google.com>
- 1074388b2d729f4a0fa6950dba53031fcc5cb457 [ir][spirv-writer] Implement divide by James Price <jrprice@google.com>
- 7fbc40d4bf9c63076cb85ec848eca7af2766defe [ir][spirv-writer] Emit swizzle instructions by James Price <jrprice@google.com>
- 8b64bd7619566c1e46782afa633691857beb74e6 [ir][spirv-writer] Emit texture and sampler vars by James Price <jrprice@google.com>
- b12e7191984babddbb44e0696d1e038b350d8930 [ir][spirv-writer] Emit texture and sampler types by James Price <jrprice@google.com>
- 33ef3da827b96aaae481197b8b51edd583f67dda [ir] Fix ToProgram tests by James Price <jrprice@google.com>
- 2e8692c9e3d6859167c10f7b2a1e1332b79db33e [tint][ir][ToProgram] Emit builtin calls by Ben Clayton <bclayton@google.com>
- d97942b4360769169f64a119e2e55d6afac5800d [tint][ir][ToProgram] Emit var binding attributes by Ben Clayton <bclayton@google.com>
- d95359f2d8f0161a0340b9ffbecf01624905a661 [tint][ir] Fix indexing of abstract typed constants by Ben Clayton <bclayton@google.com>
- cc4e27acadd2bde663d40d2d4472e907aece17db [tint][ir][ToProgram] Validate before emitting. by Ben Clayton <bclayton@google.com>
- faadfb1d93c2aede8a4e3645394ae44384bb4d56 [ir][validation] Walk through if/switch/loop in the valid... by dan sinclair <dsinclair@chromium.org>
- e9cd719e224eefafd9e16e6bd18750d723d11214 [ir][spirv-writer] Rework remaining unit tests by James Price <jrprice@google.com>
- fec334a929ebd13dd20b44cae6eee309758783f1 [ir][spirv-writer] Emit OpUndef when needed by James Price <jrprice@google.com>
- 96cac41fe99e292e5e719743617bdff3acb5e2a3 [ir][spirv-writer] Fix block labels for loop phis by James Price <jrprice@google.com>
- cf2bdc0ed8caa689a380d744cc827947ec43d018 [tint][ir][ToProgram] Implement Convert by Ben Clayton <bclayton@google.com>
- e17da01c3a775f1c37ec63c6d2ea2aecc9547fb3 [tint][ir][ToProgram] Implement Access by Ben Clayton <bclayton@google.com>
- 271d21552251c5dfdd50e0c6c008261edff3630f [tint][ir][ToProgram] Implement Construct by Ben Clayton <bclayton@google.com>
- a83202be6973e9e6ea2e3396199b50febd50ba09 [tint][ir][ToProgram] Add var<private>, more types. by Ben Clayton <bclayton@google.com>
- 10550d67dc682a79b8e9ab21c113d9fa9ceaa13f [tint][ir] Add roundtrip fuzzer by Ben Clayton <bclayton@google.com>
- 8232b503e3a2184f7ee60be6d4696ae7b6336138 [tint][ir] static_assert on non-deterministic instruction... by Ben Clayton <bclayton@google.com>
- 13d96766f785dd4118bc8aa86e22fcd05085077f Fixup syntax tree build. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 7171205a3fc23f5d7a8f971bf551ebc3adb2498c
Change-Id: I392f2a3160af84b3b9ab42b3345e4a4f5aca811a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/139780
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index a315fd4..38c374e 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -2042,15 +2042,19 @@
sources += [
"writer/spirv/ir/generator_impl_ir_access_test.cc",
"writer/spirv/ir/generator_impl_ir_binary_test.cc",
+ "writer/spirv/ir/generator_impl_ir_bitcast_test.cc",
"writer/spirv/ir/generator_impl_ir_builtin_test.cc",
"writer/spirv/ir/generator_impl_ir_constant_test.cc",
"writer/spirv/ir/generator_impl_ir_construct_test.cc",
+ "writer/spirv/ir/generator_impl_ir_convert_test.cc",
"writer/spirv/ir/generator_impl_ir_function_test.cc",
"writer/spirv/ir/generator_impl_ir_if_test.cc",
"writer/spirv/ir/generator_impl_ir_loop_test.cc",
"writer/spirv/ir/generator_impl_ir_switch_test.cc",
+ "writer/spirv/ir/generator_impl_ir_swizzle_test.cc",
"writer/spirv/ir/generator_impl_ir_test.cc",
"writer/spirv/ir/generator_impl_ir_type_test.cc",
+ "writer/spirv/ir/generator_impl_ir_unary_test.cc",
"writer/spirv/ir/generator_impl_ir_var_test.cc",
"writer/spirv/ir/test_helper_ir.h",
]
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 967cef2..68fe434 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1313,15 +1313,19 @@
list(APPEND TINT_TEST_SRCS
writer/spirv/ir/generator_impl_ir_access_test.cc
writer/spirv/ir/generator_impl_ir_binary_test.cc
+ writer/spirv/ir/generator_impl_ir_bitcast_test.cc
writer/spirv/ir/generator_impl_ir_builtin_test.cc
writer/spirv/ir/generator_impl_ir_constant_test.cc
writer/spirv/ir/generator_impl_ir_construct_test.cc
+ writer/spirv/ir/generator_impl_ir_convert_test.cc
writer/spirv/ir/generator_impl_ir_function_test.cc
writer/spirv/ir/generator_impl_ir_if_test.cc
writer/spirv/ir/generator_impl_ir_loop_test.cc
writer/spirv/ir/generator_impl_ir_switch_test.cc
+ writer/spirv/ir/generator_impl_ir_swizzle_test.cc
writer/spirv/ir/generator_impl_ir_test.cc
writer/spirv/ir/generator_impl_ir_type_test.cc
+ writer/spirv/ir/generator_impl_ir_unary_test.cc
writer/spirv/ir/generator_impl_ir_var_test.cc
writer/spirv/ir/test_helper_ir.h
)
diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc
index ada1603..b7b1109 100644
--- a/src/tint/cmd/main.cc
+++ b/src/tint/cmd/main.cc
@@ -619,6 +619,9 @@
// TODO(jrprice): Provide a way for the user to set non-default options.
tint::writer::msl::Options gen_options;
+#if TINT_BUILD_IR
+ gen_options.use_tint_ir = options.use_ir;
+#endif
gen_options.disable_robustness = !options.enable_robustness;
gen_options.disable_workgroup_init = options.disable_workgroup_init;
gen_options.external_texture_options.bindings_map =
diff --git a/src/tint/fuzzers/CMakeLists.txt b/src/tint/fuzzers/CMakeLists.txt
index 5640f95..c53b0ec 100644
--- a/src/tint/fuzzers/CMakeLists.txt
+++ b/src/tint/fuzzers/CMakeLists.txt
@@ -56,6 +56,10 @@
add_tint_fuzzer(tint_wgsl_reader_spv_writer_fuzzer)
endif()
+if (${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_IR})
+ add_tint_fuzzer(tint_ir_roundtrip_fuzzer)
+endif()
+
if (${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_HLSL_WRITER})
add_tint_fuzzer(tint_wgsl_reader_hlsl_writer_fuzzer)
endif()
diff --git a/src/tint/fuzzers/tint_ir_roundtrip_fuzzer.cc b/src/tint/fuzzers/tint_ir_roundtrip_fuzzer.cc
new file mode 100644
index 0000000..0ec7710
--- /dev/null
+++ b/src/tint/fuzzers/tint_ir_roundtrip_fuzzer.cc
@@ -0,0 +1,67 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <string>
+#include <unordered_set>
+
+#include "src/tint/ir/from_program.h"
+#include "src/tint/ir/to_program.h"
+#include "src/tint/reader/wgsl/parser_impl.h"
+#include "src/tint/writer/wgsl/generator.h"
+
+[[noreturn]] void TintInternalCompilerErrorReporter(const tint::diag::List& diagnostics) {
+ auto printer = tint::diag::Printer::create(stderr, true);
+ tint::diag::Formatter{}.format(diagnostics, printer.get());
+ __builtin_trap();
+}
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ std::string str(reinterpret_cast<const char*>(data), size);
+
+ tint::SetInternalCompilerErrorReporter(&TintInternalCompilerErrorReporter);
+
+ tint::Source::File file("test.wgsl", str);
+
+ // Parse the wgsl, create the src program
+ tint::reader::wgsl::ParserImpl parser(&file);
+ parser.set_max_errors(1);
+ if (!parser.Parse()) {
+ return 0;
+ }
+ auto src = parser.program();
+ if (!src.IsValid()) {
+ return 0;
+ }
+
+ auto ir = tint::ir::FromProgram(&src);
+ if (!ir) {
+ std::cerr << ir.Failure() << std::endl;
+ __builtin_trap();
+ }
+
+ auto dst = tint::ir::ToProgram(ir.Get());
+ if (!dst.IsValid()) {
+#if TINT_BUILD_WGSL_WRITER
+ if (auto result = tint::writer::wgsl::Generate(&dst, {}); result.success) {
+ std::cerr << result.wgsl << std::endl << std::endl;
+ }
+#endif
+
+ std::cerr << dst.Diagnostics() << std::endl;
+ __builtin_trap();
+ }
+
+ return 0;
+}
diff --git a/src/tint/ir/bitcast.h b/src/tint/ir/bitcast.h
index 4e437ae..724a709 100644
--- a/src/tint/ir/bitcast.h
+++ b/src/tint/ir/bitcast.h
@@ -31,6 +31,9 @@
/// @param val the value being bitcast
Bitcast(InstructionResult* result, Value* val);
~Bitcast() override;
+
+ /// @returns the operand value
+ Value* Val() { return operands_[kValueOperandOffset]; }
};
} // namespace tint::ir
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 95c6df0..756ad60 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -67,6 +67,23 @@
/// Builds an ir::Module
class Builder {
+ /// Evaluates to true if T is a non-reference instruction pointer.
+ template <typename T>
+ static constexpr bool IsNonRefInstPtr =
+ std::is_pointer_v<T> && std::is_base_of_v<ir::Instruction, std::remove_pointer_t<T>>;
+
+ /// static_assert()s that ARGS contains no more than one non-reference instruction pointer.
+ /// This is used to detect patterns where C++ non-deterministic evaluation order may cause
+ /// instruction ordering bugs.
+ template <typename... ARGS>
+ static constexpr void CheckForNonDeterministicEvaluation() {
+ constexpr bool possibly_non_deterministic_eval =
+ ((IsNonRefInstPtr<ARGS> ? 1 : 0) + ...) > 1;
+ static_assert(!possibly_non_deterministic_eval,
+ "Detected possible non-deterministic ordering of instructions. "
+ "Consider hoisting Builder call arguments to separate statements.");
+ }
+
/// A helper used to enable overloads if the first type in `TYPES` is a utils::Vector or
/// utils::VectorRef.
template <typename... TYPES>
@@ -141,8 +158,8 @@
/// @returns the instruction
template <typename T>
ir::If* If(T&& condition) {
- return Append(
- ir.instructions.Create<ir::If>(Value(std::forward<T>(condition)), Block(), Block()));
+ auto* cond_val = Value(std::forward<T>(condition));
+ return Append(ir.instructions.Create<ir::If>(cond_val, Block(), Block()));
}
/// Creates a loop instruction
@@ -154,7 +171,8 @@
/// @returns the instruction
template <typename T>
ir::Switch* Switch(T&& condition) {
- return Append(ir.instructions.Create<ir::Switch>(Value(std::forward<T>(condition))));
+ auto* cond_val = Value(std::forward<T>(condition));
+ return Append(ir.instructions.Create<ir::Switch>(cond_val));
}
/// Creates a case for the switch @p s with the given selectors
@@ -201,37 +219,35 @@
/// @returns the new constant
ir::Constant* Constant(bool v) { return Constant(ir.constant_values.Get(v)); }
- /// Creates a ir::Constant for the given number
- /// @param number the number value
- /// @returns the new constant
- template <typename T, typename = std::enable_if_t<IsNumeric<T>>>
- ir::Constant* Value(T&& number) {
- return Constant(std::forward<T>(number));
- }
-
- /// Pass-through overload for nullptr values
- /// @returns nullptr
- ir::Value* Value(std::nullptr_t) { return nullptr; }
-
- /// Pass-through overload for Value()
- /// @param v the ir::Value pointer
- /// @returns @p v
- ir::Value* Value(ir::Value* v) { return v; }
-
- /// Extract the first result from the instruction
- /// @param inst the instruction
- /// @returns the result value
- ir::Value* Value(ir::Instruction* inst) {
- TINT_ASSERT(IR, inst->HasResults() && !inst->HasMultiResults());
- return inst->Result();
- }
-
- /// Creates a value from the given number
- /// @param n the number
- /// @returns the value
+ /// @param in the input value. One of: nullptr, ir::Value*, ir::Instruction* or a numeric value.
+ /// @returns an ir::Value* from the given argument.
template <typename T>
- ir::Value* Value(Number<T> n) {
- return Constant(n);
+ ir::Value* Value(T&& in) {
+ using D = std::decay_t<T>;
+ constexpr bool is_null = std::is_same_v<T, std::nullptr_t>;
+ constexpr bool is_ptr = std::is_pointer_v<D>;
+ constexpr bool is_numeric = IsNumeric<D>;
+ static_assert(is_null || is_ptr || is_numeric, "invalid argument type for Value()");
+
+ if constexpr (is_null) {
+ return nullptr;
+ } else if constexpr (is_ptr) {
+ using P = std::remove_pointer_t<D>;
+ constexpr bool is_value = std::is_base_of_v<ir::Value, P>;
+ constexpr bool is_instruction = std::is_base_of_v<ir::Instruction, P>;
+ static_assert(is_value || is_instruction, "invalid pointer type for Value()");
+
+ if constexpr (is_value) {
+ return in; /// Pass-through
+ } else if constexpr (is_instruction) {
+ /// Extract the first result from the instruction
+ TINT_ASSERT(IR, in->HasResults() && !in->HasMultiResults());
+ return in->Result();
+ }
+ } else if constexpr (is_numeric) {
+ /// Creates a value from the given number
+ return Constant(in);
+ }
}
/// Pass-through overload for Values() with vector-like argument
@@ -254,6 +270,7 @@
/// @returns a vector of ir::Value* built from transforming the arguments with Value()
template <typename... ARGS, typename = DisableIfVectorLike<ARGS...>>
auto Values(ARGS&&... args) {
+ CheckForNonDeterministicEvaluation<ARGS...>();
return utils::Vector{Value(std::forward<ARGS>(args))...};
}
@@ -265,9 +282,11 @@
/// @returns the operation
template <typename LHS, typename RHS>
ir::Binary* Binary(enum Binary::Kind kind, const type::Type* type, LHS&& lhs, RHS&& rhs) {
- return Append(ir.instructions.Create<ir::Binary>(InstructionResult(type), kind,
- Value(std::forward<LHS>(lhs)),
- Value(std::forward<RHS>(rhs))));
+ CheckForNonDeterministicEvaluation<LHS, RHS>();
+ auto* lhs_val = Value(std::forward<LHS>(lhs));
+ auto* rhs_val = Value(std::forward<RHS>(rhs));
+ return Append(
+ ir.instructions.Create<ir::Binary>(InstructionResult(type), kind, lhs_val, rhs_val));
}
/// Creates an And operation
@@ -449,8 +468,8 @@
/// @returns the operation
template <typename VAL>
ir::Unary* Unary(enum Unary::Kind kind, const type::Type* type, VAL&& val) {
- return Append(ir.instructions.Create<ir::Unary>(InstructionResult(type), kind,
- Value(std::forward<VAL>(val))));
+ auto* value = Value(std::forward<VAL>(val));
+ return Append(ir.instructions.Create<ir::Unary>(InstructionResult(type), kind, value));
}
/// Creates a Complement operation
@@ -486,8 +505,8 @@
/// @returns the instruction
template <typename VAL>
ir::Bitcast* Bitcast(const type::Type* type, VAL&& val) {
- return Append(ir.instructions.Create<ir::Bitcast>(InstructionResult(type),
- Value(std::forward<VAL>(val))));
+ auto* value = Value(std::forward<VAL>(val));
+ return Append(ir.instructions.Create<ir::Bitcast>(InstructionResult(type), value));
}
/// Creates a discard instruction
@@ -541,19 +560,21 @@
/// @returns the instruction
template <typename VAL>
ir::Load* Load(VAL&& from) {
- auto* val = Value(std::forward<VAL>(from));
+ auto* value = Value(std::forward<VAL>(from));
return Append(
- ir.instructions.Create<ir::Load>(InstructionResult(val->Type()->UnwrapPtr()), val));
+ ir.instructions.Create<ir::Load>(InstructionResult(value->Type()->UnwrapPtr()), value));
}
/// Creates a store instruction
/// @param to the expression being stored too
/// @param from the expression being stored
/// @returns the instruction
- template <typename TO, typename ARG>
- ir::Store* Store(TO&& to, ARG&& from) {
- return Append(ir.instructions.Create<ir::Store>(Value(std::forward<TO>(to)),
- Value(std::forward<ARG>(from))));
+ template <typename TO, typename FROM>
+ ir::Store* Store(TO&& to, FROM&& from) {
+ CheckForNonDeterministicEvaluation<TO, FROM>();
+ auto* to_val = Value(std::forward<TO>(to));
+ auto* from_val = Value(std::forward<FROM>(from));
+ return Append(ir.instructions.Create<ir::Store>(to_val, from_val));
}
/// Creates a new `var` declaration
@@ -585,7 +606,8 @@
return Append(ir.instructions.Create<ir::Return>(func));
}
}
- return Append(ir.instructions.Create<ir::Return>(func, Value(std::forward<ARG>(value))));
+ auto* val = Value(std::forward<ARG>(value));
+ return Append(ir.instructions.Create<ir::Return>(func, val));
}
/// Creates a loop next iteration instruction
@@ -605,8 +627,10 @@
/// @returns the instruction
template <typename CONDITION, typename... ARGS>
ir::BreakIf* BreakIf(ir::Loop* loop, CONDITION&& condition, ARGS&&... args) {
- return Append(ir.instructions.Create<ir::BreakIf>(
- Value(std::forward<CONDITION>(condition)), loop, Values(std::forward<ARGS>(args)...)));
+ CheckForNonDeterministicEvaluation<CONDITION, ARGS...>();
+ auto* cond_val = Value(std::forward<CONDITION>(condition));
+ return Append(ir.instructions.Create<ir::BreakIf>(cond_val, loop,
+ Values(std::forward<ARGS>(args)...)));
}
/// Creates a continue instruction
@@ -684,8 +708,9 @@
/// @returns the instruction
template <typename OBJ, typename... ARGS>
ir::Access* Access(const type::Type* type, OBJ&& object, ARGS&&... indices) {
- return Append(ir.instructions.Create<ir::Access>(InstructionResult(type),
- Value(std::forward<OBJ>(object)),
+ CheckForNonDeterministicEvaluation<OBJ, ARGS...>();
+ auto* obj_val = Value(std::forward<OBJ>(object));
+ return Append(ir.instructions.Create<ir::Access>(InstructionResult(type), obj_val,
Values(std::forward<ARGS>(indices)...)));
}
@@ -696,8 +721,9 @@
/// @returns the instruction
template <typename OBJ>
ir::Swizzle* Swizzle(const type::Type* type, OBJ&& object, utils::VectorRef<uint32_t> indices) {
- return Append(ir.instructions.Create<ir::Swizzle>(
- InstructionResult(type), Value(std::forward<OBJ>(object)), std::move(indices)));
+ auto* obj_val = Value(std::forward<OBJ>(object));
+ return Append(ir.instructions.Create<ir::Swizzle>(InstructionResult(type), obj_val,
+ std::move(indices)));
}
/// Creates a new `Swizzle`
@@ -709,8 +735,8 @@
ir::Swizzle* Swizzle(const type::Type* type,
OBJ&& object,
std::initializer_list<uint32_t> indices) {
- return Append(ir.instructions.Create<ir::Swizzle>(InstructionResult(type),
- Value(std::forward<OBJ>(object)),
+ auto* obj_val = Value(std::forward<OBJ>(object));
+ return Append(ir.instructions.Create<ir::Swizzle>(InstructionResult(type), obj_val,
utils::Vector<uint32_t, 4>(indices)));
}
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 38e9be4..5bb2201 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -538,18 +538,11 @@
out_ << "if ";
EmitOperand(if_, if_->Condition(), If::kConditionOperandOffset);
- bool has_true = !if_->True()->IsEmpty();
bool has_false = !if_->False()->IsEmpty();
- out_ << " [";
- if (has_true) {
- out_ << "t: %b" << IdOf(if_->True());
- }
+ out_ << " [t: %b" << IdOf(if_->True());
if (has_false) {
- if (has_true) {
- out_ << ", ";
- }
- out_ << "f: %b" << IdOf(if_->False());
+ out_ << ", f: %b" << IdOf(if_->False());
}
out_ << "]";
sm.Store(if_);
@@ -557,10 +550,12 @@
out_ << " { # " << NameOf(if_);
EmitLine();
- if (has_true) {
+ // True block is assumed to have instructions
+ {
ScopedIndent si(indent_size_);
EmitBlock(if_->True(), "true");
}
+
if (has_false) {
ScopedIndent si(indent_size_);
EmitBlock(if_->False(), "false");
@@ -584,9 +579,7 @@
if (!l->Initializer()->IsEmpty()) {
parts.Push("i: %b" + std::to_string(IdOf(l->Initializer())));
}
- if (!l->Body()->IsEmpty()) {
- parts.Push("b: %b" + std::to_string(IdOf(l->Body())));
- }
+ parts.Push("b: %b" + std::to_string(IdOf(l->Body())));
if (!l->Continuing()->IsEmpty()) {
parts.Push("c: %b" + std::to_string(IdOf(l->Continuing())));
@@ -603,7 +596,8 @@
EmitBlock(l->Initializer(), "initializer");
}
- if (!l->Body()->IsEmpty()) {
+ // Loop is assumed to always have a body
+ {
ScopedIndent si(indent_size_);
EmitBlock(l->Body(), "body");
}
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 69b80ea..19a16ea 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -867,6 +867,9 @@
std::vector<const ast::Expression*> accessors;
const ast::Expression* object = expr;
while (true) {
+ if (program_->Sem().GetVal(object)->ConstantValue()) {
+ break; // Reached a constant expression. Stop traversal.
+ }
if (auto* array = object->As<ast::IndexAccessorExpression>()) {
accessors.push_back(object);
object = array->object;
@@ -944,7 +947,7 @@
}
bool GenerateMemberAccessor(const ast::MemberAccessorExpression* expr, AccessorInfo& info) {
- auto* expr_sem = program_->Sem().Get(expr)->UnwrapLoad();
+ auto* expr_sem = program_->Sem().Get(expr)->Unwrap();
return tint::Switch(
expr_sem, //
@@ -966,7 +969,7 @@
// intermediate steps need different result types.
auto* result_type = info.result_type;
- // Emit any preceeding member/index accessors
+ // Emit any preceding member/index accessors
if (!info.indices.IsEmpty()) {
// The access chain is being split, the initial part of than will have a
// resulting type that matches the object being swizzled.
diff --git a/src/tint/ir/from_program_accessor_test.cc b/src/tint/ir/from_program_accessor_test.cc
index 3628f10..cd470c0 100644
--- a/src/tint/ir/from_program_accessor_test.cc
+++ b/src/tint/ir/from_program_accessor_test.cc
@@ -35,7 +35,7 @@
auto* a = Var("a", ty.vec3<u32>(), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", IndexAccessor(a, 2_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -58,7 +58,7 @@
auto* a = Var("a", ty.mat3x4<f32>(), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", IndexAccessor(IndexAccessor(a, 2_u), 3_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -85,7 +85,7 @@
});
auto* a = Var("a", ty.Of(s), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", MemberAccessor(a, "foo")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -121,7 +121,7 @@
});
auto* a = Var("a", ty.Of(outer), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "foo"), "bar")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -166,7 +166,7 @@
auto* expr = Decl(Let(
"b",
MemberAccessor(IndexAccessor(MemberAccessor(IndexAccessor(a, 0_u), "foo"), 1_u), "bar")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -200,7 +200,7 @@
auto* a = Var("a", ty.array<u32, 4>(), builtin::AddressSpace::kFunction);
auto* assign = Assign(IndexAccessor(a, 2_u), 0_u);
- WrapInFunction(Block(utils::Vector{Decl(a), assign}));
+ WrapInFunction(Decl(a), assign);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -223,7 +223,7 @@
auto* a = Var("a", ty.vec2<f32>(), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", MemberAccessor(a, "y")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -246,7 +246,7 @@
auto* a = Var("a", ty.vec3<f32>(), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", MemberAccessor(a, "zyxz")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -269,7 +269,7 @@
auto* a = Var("a", ty.vec3<f32>(), builtin::AddressSpace::kFunction);
auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "zyx"), "yy")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -300,7 +300,7 @@
auto* expr = Decl(Let(
"b",
IndexAccessor(MemberAccessor(MemberAccessor(MemberAccessor(a, "foo"), "zyx"), "yx"), 0_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -330,7 +330,7 @@
// let b = a[2]
auto* a = Let("a", ty.vec3<u32>(), vec(ty.u32(), 3));
auto* expr = Decl(Let("b", IndexAccessor(a, 2_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -351,7 +351,7 @@
auto* a = Let("a", ty.mat3x4<f32>(), Call<mat3x4<f32>>());
auto* expr = Decl(Let("b", IndexAccessor(IndexAccessor(a, 2_u), 3_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -376,7 +376,7 @@
});
auto* a = Let("a", ty.Of(s), Call("MyStruct"));
auto* expr = Decl(Let("b", MemberAccessor(a, "foo")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -410,7 +410,7 @@
});
auto* a = Let("a", ty.Of(outer), Call("Outer"));
auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "foo"), "bar")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -453,7 +453,7 @@
auto* expr = Decl(Let(
"b",
MemberAccessor(IndexAccessor(MemberAccessor(IndexAccessor(a, 0_u), "foo"), 1_u), "bar")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -485,7 +485,7 @@
auto* a = Let("a", ty.vec2<f32>(), vec(ty.f32(), 2));
auto* expr = Decl(Let("b", MemberAccessor(a, "y")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -506,7 +506,7 @@
auto* a = Let("a", ty.vec3<f32>(), vec(ty.f32(), 3));
auto* expr = Decl(Let("b", MemberAccessor(a, "zyxz")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -527,7 +527,7 @@
auto* a = Let("a", ty.vec3<f32>(), vec(ty.f32(), 3));
auto* expr = Decl(Let("b", MemberAccessor(MemberAccessor(a, "zyx"), "yy")));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -556,7 +556,7 @@
auto* expr = Decl(Let(
"b",
IndexAccessor(MemberAccessor(MemberAccessor(MemberAccessor(a, "foo"), "zyx"), "yx"), 0_u)));
- WrapInFunction(Block(utils::Vector{Decl(a), expr}));
+ WrapInFunction(Decl(a), expr);
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
@@ -579,5 +579,53 @@
)");
}
+TEST_F(IR_FromProgramAccessorTest, Accessor_Const_AbstractVectorWithIndex) {
+ // const v = vec3(1, 2, 3);
+ // let i = 1;
+ // var b = v[i];
+
+ auto* v = Const("v", Call<vec3<Infer>>(1_a, 2_a, 3_a));
+ auto* i = Let("i", Expr(1_i));
+ auto* b = Var("b", IndexAccessor("v", "i"));
+ WrapInFunction(v, i, b);
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %2:i32 = access vec3<i32>(1i, 2i, 3i), 1i
+ %b:ptr<function, i32, read_write> = var, %2
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_FromProgramAccessorTest, Accessor_Const_AbstractVectorWithSwizzleAndIndex) {
+ // const v = vec3(1, 2, 3);
+ // let i = 1;
+ // var b = v.rg[i];
+
+ auto* v = Const("v", Call<vec3<Infer>>(1_a, 2_a, 3_a));
+ auto* i = Let("i", Expr(1_i));
+ auto* b = Var("b", IndexAccessor(MemberAccessor("v", "rg"), "i"));
+ WrapInFunction(v, i, b);
+
+ auto m = Build();
+ ASSERT_TRUE(m) << (!m ? m.Failure() : "");
+
+ EXPECT_EQ(Disassemble(m.Get()),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ %2:i32 = access vec2<i32>(1i, 2i), 1i
+ %b:ptr<function, i32, read_write> = var, %2
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 3ca56a6..3e30cfb 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -18,13 +18,17 @@
#include <tuple>
#include <utility>
+#include "src/tint/constant/splat.h"
#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/break_if.h"
+#include "src/tint/ir/builtin_call.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/construct.h"
#include "src/tint/ir/continue.h"
+#include "src/tint/ir/convert.h"
#include "src/tint/ir/exit_if.h"
#include "src/tint/ir/exit_loop.h"
#include "src/tint/ir/exit_switch.h"
@@ -39,7 +43,9 @@
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/unary.h"
+#include "src/tint/ir/unreachable.h"
#include "src/tint/ir/user_call.h"
+#include "src/tint/ir/validate.h"
#include "src/tint/ir/var.h"
#include "src/tint/program_builder.h"
#include "src/tint/switch.h"
@@ -73,16 +79,20 @@
namespace {
-/// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by its
-/// single place of usage. Attempting to use this value a second time should result in an ICE.
-struct ConsumedValue {};
-
class State {
public:
explicit State(Module& m) : mod(m) {}
Program Run() {
- // TODO(crbug.com/tint/1902): Emit root block
+ if (auto res = Validate(mod); !res) {
+ // IR module failed validation.
+ b.Diagnostics() = res.Failure();
+ return Program{std::move(b)};
+ }
+
+ if (mod.root_block) {
+ RootBlock(mod.root_block);
+ }
// TODO(crbug.com/tint/1902): Emit user-declared types
for (auto* fn : mod.functions) {
Fn(fn);
@@ -91,21 +101,43 @@
}
private:
+ /// The AST representation for an IR pointer type
+ enum class PtrKind {
+ kPtr, // IR pointer is represented in the AST as a pointer
+ kRef, // IR pointer is represented in the AST as a reference
+ };
+
/// The source IR module
Module& mod;
/// The target ProgramBuilder
ProgramBuilder b;
- using ValueBinding = std::variant<Symbol, const ast::Expression*, ConsumedValue>;
+ /// The structure for a value held by a 'let', 'var' or parameter.
+ struct VariableValue {
+ Symbol name; // Name of the variable
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
- /// A hashmap of value to one of:
- /// * Symbol - Name of 'let' (non-inlinable value), 'var' or parameter.
- /// * ast::Expression* - single use, inlined expression.
- /// * ConsumedValue - a special value used to indicate that the value has already been
- /// consumed.
+ /// The structure for an inlined value
+ struct InlinedValue {
+ const ast::Expression* expr = nullptr;
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
+
+ /// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by
+ /// its single place of usage. Attempting to use this value a second time should result in an
+ /// ICE.
+ struct ConsumedValue {};
+
+ using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
+
+ /// IR values to their representation
utils::Hashmap<Value*, ValueBinding, 32> bindings_;
+ /// Names for values
+ utils::Hashmap<Value*, Symbol, 32> names_;
+
/// The nesting depth of the currently generated AST
/// 0 is module scope
/// 1 is root-level function scope
@@ -119,21 +151,36 @@
/// The current switch case block
ir::Block* current_switch_case_ = nullptr;
- // Values that can be inlined.
+ /// Values that can be inlined.
utils::Hashset<ir::Value*, 64> can_inline_;
+ /// Set of enable directives emitted.
+ utils::Hashset<builtin::Extension, 4> enables_;
+
+ /// Map of struct to output program name.
+ utils::Hashmap<const type::Struct*, Symbol, 8> structs_;
+
+ void RootBlock(ir::Block* root) {
+ for (auto* inst : *root) {
+ tint::Switch(
+ inst, //
+ [&](ir::Var* var) { Var(var); }, //
+ [&](Default) { UNHANDLED_CASE(inst); });
+ }
+ }
const ast::Function* Fn(ir::Function* fn) {
SCOPED_NESTING();
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
static constexpr size_t N = decltype(ast::Function::params)::static_length;
auto params = utils::Transform<N>(fn->Params(), [&](FunctionParam* param) {
- auto name = BindName(param);
+ auto name = NameFor(param);
+ Bind(param, name, PtrKind::kPtr);
auto ty = Type(param->Type());
return b.Param(name, ty);
});
- auto name = BindName(fn);
+ auto name = NameFor(fn);
auto ret_ty = Type(fn->ReturnType());
auto* body = Block(fn->Block());
utils::Vector<const ast::Attribute*, 1> attrs{};
@@ -231,22 +278,24 @@
void Instruction(ir::Instruction* inst) {
tint::Switch(
inst, //
+ [&](ir::Access* i) { Access(i); }, //
[&](ir::Binary* i) { Binary(i); }, //
[&](ir::BreakIf* i) { BreakIf(i); }, //
[&](ir::Call* i) { Call(i); }, //
+ [&](ir::Continue*) {}, //
[&](ir::ExitIf*) {}, //
- [&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
[&](ir::ExitLoop* i) { ExitLoop(i); }, //
+ [&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
[&](ir::If* i) { If(i); }, //
[&](ir::Load* l) { Load(l); }, //
[&](ir::Loop* l) { Loop(l); }, //
+ [&](ir::NextIteration*) {}, //
[&](ir::Return* i) { Return(i); }, //
[&](ir::Store* i) { Store(i); }, //
[&](ir::Switch* i) { Switch(i); }, //
[&](ir::Unary* u) { Unary(u); }, //
+ [&](ir::Unreachable*) {}, //
[&](ir::Var* i) { Var(i); }, //
- [&](ir::NextIteration*) {}, //
- [&](ir::Continue*) {}, //
[&](Default) { UNHANDLED_CASE(inst); });
}
@@ -393,22 +442,30 @@
void Var(ir::Var* var) {
auto* val = var->Result();
- Symbol name = BindName(val);
+ Symbol name = NameFor(var->Result());
+ Bind(var->Result(), name, PtrKind::kRef);
auto* ptr = As<type::Pointer>(val->Type());
auto ty = Type(ptr->StoreType());
+
+ utils::Vector<const ast::Attribute*, 4> attrs;
+ if (auto bp = var->BindingPoint()) {
+ attrs.Push(b.Group(AInt(bp->group)));
+ attrs.Push(b.Binding(AInt(bp->binding)));
+ }
+
const ast::Expression* init = nullptr;
if (var->Initializer()) {
init = Expr(var->Initializer());
}
switch (ptr->AddressSpace()) {
case builtin::AddressSpace::kFunction:
- Append(b.Decl(b.Var(name, ty, init)));
+ Append(b.Decl(b.Var(name, ty, init, std::move(attrs))));
return;
case builtin::AddressSpace::kStorage:
- Append(b.Decl(b.Var(name, ty, init, ptr->Access(), ptr->AddressSpace())));
+ b.GlobalVar(name, ty, init, ptr->Access(), ptr->AddressSpace(), std::move(attrs));
return;
default:
- Append(b.Decl(b.Var(name, ty, init, ptr->AddressSpace())));
+ b.GlobalVar(name, ty, init, ptr->AddressSpace(), std::move(attrs));
return;
}
}
@@ -420,16 +477,35 @@
}
void Call(ir::Call* call) {
- auto args = utils::Transform<2>(call->Args(), [&](ir::Value* arg) { return Expr(arg); });
+ auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) {
+ // Pointer-like arguments are passed by pointer, never reference.
+ return Expr(arg, PtrKind::kPtr);
+ });
tint::Switch(
call, //
[&](ir::UserCall* c) {
- auto* expr = b.Call(BindName(c->Func()), std::move(args));
+ auto* expr = b.Call(NameFor(c->Func()), std::move(args));
if (!call->HasResults() || call->Result()->Usages().IsEmpty()) {
Append(b.CallStmt(expr));
return;
}
- Bind(c->Result(), expr);
+ Bind(c->Result(), expr, PtrKind::kPtr);
+ },
+ [&](ir::BuiltinCall* c) {
+ auto* expr = b.Call(c->Func(), std::move(args));
+ if (!call->HasResults() || call->Result()->Usages().IsEmpty()) {
+ Append(b.CallStmt(expr));
+ return;
+ }
+ Bind(c->Result(), expr, PtrKind::kPtr);
+ },
+ [&](ir::Construct* c) {
+ auto ty = Type(c->Result()->Type());
+ Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
+ },
+ [&](ir::Convert* c) {
+ auto ty = Type(c->Result()->Type());
+ Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
},
[&](Default) { UNHANDLED_CASE(call); });
}
@@ -449,6 +525,57 @@
Bind(u->Result(), expr);
}
+ void Access(ir::Access* a) {
+ auto* expr = Expr(a->Object());
+ auto* obj_ty = a->Object()->Type()->UnwrapPtr();
+ for (auto* index : a->Indices()) {
+ tint::Switch(
+ obj_ty,
+ [&](const type::Vector* vec) {
+ TINT_DEFER(obj_ty = vec->type());
+ if (auto* c = index->As<ir::Constant>()) {
+ switch (c->Value()->ValueAs<int>()) {
+ case 0:
+ expr = b.MemberAccessor(expr, "x");
+ return;
+ case 1:
+ expr = b.MemberAccessor(expr, "y");
+ return;
+ case 2:
+ expr = b.MemberAccessor(expr, "z");
+ return;
+ case 3:
+ expr = b.MemberAccessor(expr, "w");
+ return;
+ }
+ }
+ expr = b.IndexAccessor(expr, Expr(index));
+ },
+ [&](const type::Matrix* mat) {
+ obj_ty = mat->ColumnType();
+ expr = b.IndexAccessor(expr, Expr(index));
+ },
+ [&](const type::Array* arr) {
+ obj_ty = arr->ElemType();
+ expr = b.IndexAccessor(expr, Expr(index));
+ },
+ [&](const type::Struct* s) {
+ if (auto* c = index->As<ir::Constant>()) {
+ auto i = c->Value()->ValueAs<uint32_t>();
+ TINT_ASSERT_OR_RETURN(IR, i < s->Members().Length());
+ auto* member = s->Members()[i];
+ obj_ty = member->Type();
+ expr = b.IndexAccessor(expr, member->Name().NameView());
+ } else {
+ TINT_ICE(IR, b.Diagnostics())
+ << "invalid index for struct type: " << index->TypeInfo().name;
+ }
+ },
+ [&](Default) { UNHANDLED_CASE(obj_ty); });
+ }
+ Bind(a->Result(), expr);
+ }
+
void Binary(ir::Binary* e) {
if (e->Kind() == ir::Binary::Kind::kEqual) {
auto* rhs = e->RHS()->As<ir::Constant>();
@@ -516,31 +643,35 @@
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
- const ast::Expression* Expr(ir::Value* value) {
- return tint::Switch(
- value, //
- [&](ir::Constant* c) { return Constant(c); }, //
- [&](Default) -> const ast::Expression* {
+ const ast::Expression* Expr(ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef) {
+ using ExprAndPtrKind = std::pair<const ast::Expression*, PtrKind>;
+
+ auto [expr, got_ptr_kind] = tint::Switch(
+ value,
+ [&](ir::Constant* c) -> ExprAndPtrKind {
+ return {Constant(c), PtrKind::kRef};
+ },
+ [&](Default) -> ExprAndPtrKind {
auto lookup = bindings_.Find(value);
if (TINT_UNLIKELY(!lookup)) {
TINT_ICE(IR, b.Diagnostics())
<< "Expr(" << (value ? value->TypeInfo().name : "null")
<< ") value has no expression";
- return b.Expr("<error>");
+ return {};
}
return std::visit(
- [&](auto&& got) -> const ast::Expression* {
+ [&](auto&& got) -> ExprAndPtrKind {
using T = std::decay_t<decltype(got)>;
- if constexpr (std::is_same_v<T, Symbol>) {
- return b.Expr(got); // var, let or parameter.
+ if constexpr (std::is_same_v<T, VariableValue>) {
+ return {b.Expr(got.name), got.ptr_kind};
}
- if constexpr (std::is_same_v<T, const ast::Expression*>) {
+ if constexpr (std::is_same_v<T, InlinedValue>) {
// Single use (inlined) expression.
// Mark the bindings_ map entry as consumed.
*lookup = ConsumedValue{};
- return got;
+ return {got.expr, got.ptr_kind};
}
if constexpr (std::is_same_v<T, ConsumedValue>) {
@@ -550,28 +681,68 @@
TINT_ICE(IR, b.Diagnostics())
<< "Expr(" << value->TypeInfo().name << ") has unhandled value";
}
- return b.Expr("<error>");
+ return {};
},
*lookup);
});
+
+ if (!expr) {
+ return b.Expr("<error>");
+ }
+
+ if (value->Type()->Is<type::Pointer>()) {
+ return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
+ }
+
+ return expr;
}
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
- const ast::Expression* Constant(ir::Constant* c) {
+ const ast::Expression* Constant(ir::Constant* c) { return Constant(c->Value()); }
+
+ const ast::Expression* Constant(const constant::Value* c) {
+ auto composite = [&](bool can_splat) {
+ auto ty = Type(c->Type());
+ if (c->AllZero()) {
+ return b.Call(ty);
+ }
+ if (can_splat && c->Is<constant::Splat>()) {
+ return b.Call(ty, Constant(c->Index(0)));
+ }
+
+ utils::Vector<const ast::Expression*, 8> els;
+ for (size_t i = 0, n = c->NumElements(); i < n; i++) {
+ els.Push(Constant(c->Index(i)));
+ }
+ return b.Call(ty, std::move(els));
+ };
return tint::Switch(
c->Type(), //
- [&](const type::I32*) { return b.Expr(c->Value()->ValueAs<i32>()); },
- [&](const type::U32*) { return b.Expr(c->Value()->ValueAs<u32>()); },
- [&](const type::F32*) { return b.Expr(c->Value()->ValueAs<f32>()); },
- [&](const type::F16*) { return b.Expr(c->Value()->ValueAs<f16>()); },
- [&](const type::Bool*) { return b.Expr(c->Value()->ValueAs<bool>()); },
+ [&](const type::I32*) { return b.Expr(c->ValueAs<i32>()); },
+ [&](const type::U32*) { return b.Expr(c->ValueAs<u32>()); },
+ [&](const type::F32*) { return b.Expr(c->ValueAs<f32>()); },
+ [&](const type::F16*) {
+ Enable(builtin::Extension::kF16);
+ return b.Expr(c->ValueAs<f16>());
+ },
+ [&](const type::Bool*) { return b.Expr(c->ValueAs<bool>()); },
+ [&](const type::Array*) { return composite(/* can_splat */ false); },
+ [&](const type::Vector*) { return composite(/* can_splat */ true); },
+ [&](const type::Matrix*) { return composite(/* can_splat */ false); },
+ [&](const type::Struct*) { return composite(/* can_splat */ false); },
[&](Default) {
- UNHANDLED_CASE(c);
+ UNHANDLED_CASE(c->Type());
return b.Expr("<error>");
});
}
+ void Enable(builtin::Extension ext) {
+ if (enables_.Add(ext)) {
+ b.Enable(ext);
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////////////////////////
// Types
//
@@ -591,8 +762,11 @@
[&](const type::Void*) { return ast::Type{}; }, //
[&](const type::I32*) { return b.ty.i32(); }, //
[&](const type::U32*) { return b.ty.u32(); }, //
- [&](const type::F16*) { return b.ty.f16(); }, //
- [&](const type::F32*) { return b.ty.f32(); }, //
+ [&](const type::F16*) {
+ Enable(builtin::Extension::kF16);
+ return b.ty.f16();
+ },
+ [&](const type::F32*) { return b.ty.f32(); }, //
[&](const type::Bool*) { return b.ty.bool_(); },
[&](const type::Matrix* m) {
return b.ty.mat(Type(m->type()), m->columns(), m->rows());
@@ -622,7 +796,7 @@
}
return b.ty.array(el, u32(count.value()), std::move(attrs));
},
- [&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
+ [&](const type::Struct* s) { return Struct(s); },
[&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
[&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
[&](const type::DepthMultisampledTexture* t) {
@@ -661,16 +835,44 @@
});
}
+ ast::Type Struct(const type::Struct* s) {
+ auto n = structs_.GetOrCreate(s, [&] {
+ auto members = utils::Transform<8>(s->Members(), [&](const type::StructMember* m) {
+ auto ty = Type(m->Type());
+ // TODO(crbug.com/tint/1902): Emit structure member attributes
+ utils::Vector<const ast::Attribute*, 2> attrs;
+ return b.Member(m->Name().NameView(), ty, std::move(attrs));
+ });
+
+ // TODO(crbug.com/tint/1902): Emit structure attributes
+ utils::Vector<const ast::Attribute*, 2> attrs;
+
+ auto name = b.Symbols().New(s->Name().NameView());
+ b.Structure(name, std::move(members), std::move(attrs));
+ return name;
+ });
+
+ return b.ty(n);
+ }
+
+ const ast::Expression* ToPtrKind(const ast::Expression* in, PtrKind got, PtrKind want) {
+ if (want == PtrKind::kRef && got == PtrKind::kPtr) {
+ return b.Deref(in);
+ }
+ if (want == PtrKind::kPtr && got == PtrKind::kRef) {
+ return b.AddressOf(in);
+ }
+ return in;
+ }
+
////////////////////////////////////////////////////////////////////////////////////////////////
// Bindings
////////////////////////////////////////////////////////////////////////////////////////////////
- /// Creates and returns a new, unique name for the given value, or returns the previously
- /// created name.
- /// @return the value's name
- Symbol BindName(Value* value, std::string_view suggested = {}) {
- TINT_ASSERT(IR, value);
- auto& existing = bindings_.GetOrCreate(value, [&] {
+ /// @returns the AST name for the given value, creating and returning a new name on the first
+ /// call.
+ Symbol NameFor(Value* value, std::string_view suggested = {}) {
+ return names_.GetOrCreate(value, [&] {
if (!suggested.empty()) {
return b.Symbols().New(suggested);
}
@@ -679,27 +881,41 @@
}
return b.Symbols().New("v");
});
- if (auto* name = std::get_if<Symbol>(&existing); TINT_LIKELY(name)) {
- return *name;
- }
-
- TINT_ICE(IR, b.Diagnostics()) << "BindName(" << value->TypeInfo().name
- << ") called on value that has non-name binding";
- return {};
}
- template <typename T>
- void Bind(ir::Value* value, const T* expr) {
+ /// Associates the IR value @p value with the AST expression @p expr.
+ /// @p ptr_kind defines how pointer values are represented by @p expr.
+ void Bind(ir::Value* value, const ast::Expression* expr, PtrKind ptr_kind = PtrKind::kRef) {
TINT_ASSERT(IR, value);
if (can_inline_.Remove(value)) {
// Value will be inlined at its place of usage.
- bool added = bindings_.Add(value, expr);
- if (TINT_UNLIKELY(!added)) {
- TINT_ICE(IR, b.Diagnostics())
- << "Bind(" << value->TypeInfo().name << ") called twice for same node";
+ if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
+ return;
}
} else {
- Append(b.Decl(b.Let(BindName(value), expr)));
+ if (value->Type()->Is<type::Pointer>()) {
+ expr = ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
+ }
+ Symbol name = NameFor(value);
+ Append(b.Decl(b.Let(name, expr)));
+ Bind(value, name, PtrKind::kPtr);
+ return;
+ }
+
+ TINT_ICE(IR, b.Diagnostics())
+ << "Bind(" << value->TypeInfo().name << ") called twice for same value";
+ }
+
+ /// Associates the IR value @p value with the AST 'var', 'let' or parameter with the name @p
+ /// name.
+ /// @p ptr_kind defines how pointer values are represented by @p expr.
+ void Bind(ir::Value* value, Symbol name, PtrKind ptr_kind) {
+ TINT_ASSERT(IR, value);
+
+ bool added = bindings_.Add(value, VariableValue{name, ptr_kind});
+ if (TINT_UNLIKELY(!added)) {
+ TINT_ICE(IR, b.Diagnostics())
+ << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
}
diff --git a/src/tint/ir/to_program_inlining_test.cc b/src/tint/ir/to_program_inlining_test.cc
index 72b9d39..a5d84ee 100644
--- a/src/tint/ir/to_program_inlining_test.cc
+++ b/src/tint/ir/to_program_inlining_test.cc
@@ -465,6 +465,7 @@
TEST_F(IRToProgramInliningTest, LoadVar_ThenCallVoidFn_ThenUseLoad) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn = b.Function("f", ty.i32());
b.With(fn->Block(), [&] {
@@ -665,7 +666,10 @@
b.Store(var, 1_i);
auto* load = b.Load(var);
auto* if_ = b.If(true);
- b.With(if_->True(), [&] { b.Store(var, 2_i); });
+ b.With(if_->True(), [&] {
+ b.Store(var, 2_i);
+ b.ExitIf(if_);
+ });
b.Return(fn, load);
});
@@ -789,7 +793,10 @@
auto* load = b.Load(var);
auto* switch_ = b.Switch(1_i);
auto* case_ = b.Case(switch_, {Switch::CaseSelector{}});
- b.With(case_, [&] { b.Store(var, 2_i); });
+ b.With(case_, [&] {
+ b.Store(var, 2_i);
+ b.ExitSwitch(switch_);
+ });
b.Return(fn, load);
});
@@ -817,7 +824,10 @@
auto* var = b.Var(ty.ptr<function, i32>());
auto* v = b.Add(ty.i32(), 1_i, 2_i);
auto* loop = b.Loop();
- b.With(loop->Initializer(), [&] { b.Store(var, v); });
+ b.With(loop->Initializer(), [&] {
+ b.Store(var, v);
+ b.NextIteration(loop);
+ });
b.With(loop->Body(), [&] { b.ExitLoop(loop); });
b.Return(fn, 0_i);
});
@@ -843,7 +853,10 @@
auto* v_1 = b.Load(var);
auto* v_2 = b.Add(ty.i32(), v_1, 2_i);
auto* loop = b.Loop();
- b.With(loop->Initializer(), [&] { b.Store(var, v_2); });
+ b.With(loop->Initializer(), [&] {
+ b.Store(var, v_2);
+ b.NextIteration(loop);
+ });
b.With(loop->Body(), [&] { b.ExitLoop(loop); });
b.Return(fn, 0_i);
});
@@ -870,7 +883,10 @@
b.Store(var, 1_i);
auto* load = b.Load(var);
auto* loop = b.Loop();
- b.With(loop->Initializer(), [&] { b.Store(var, 2_i); });
+ b.With(loop->Initializer(), [&] {
+ b.Store(var, 2_i);
+ b.NextIteration(loop);
+ });
b.With(loop->Body(), [&] { b.ExitLoop(loop); });
b.Return(fn, load);
});
@@ -1021,11 +1037,7 @@
b.With(loop->Body(), [&] { b.Continue(loop); });
b.With(loop->Continuing(), [&] {
b.Store(var, 2_i);
- b.NextIteration(loop);
- });
- b.With(loop->Body(), [&] {
- b.Store(var, 2_i);
- b.ExitLoop(loop);
+ b.BreakIf(loop, true);
});
b.Return(fn, load);
});
@@ -1036,11 +1048,10 @@
v = 1i;
let v_1 = v;
loop {
- v = 2i;
- break;
continuing {
v = 2i;
+ break if true;
}
}
return v_1;
@@ -1056,6 +1067,7 @@
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
auto* load = b.Load(var);
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
b.Store(var, b.Add(ty.i32(), load, 1_i));
b.ExitLoop(loop);
@@ -1088,6 +1100,7 @@
auto* loop = b.Loop();
b.With(loop->Initializer(), [&] {
auto* load = b.Load(var);
+ b.NextIteration(loop);
b.With(loop->Body(), [&] { b.Continue(loop); });
b.With(loop->Continuing(), [&] {
b.Store(var, b.Add(ty.i32(), load, 1_i));
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index c5b482a..14bc8a1 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -104,6 +104,462 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Function Call
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, FnCall_NoArgs_NoRet) {
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+ a();
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FnCall_NoArgs_Ret_i32) {
+ Test(R"(
+fn a() -> i32 {
+ return 1i;
+}
+
+fn b() {
+ var i : i32 = a();
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FnCall_3Args_NoRet) {
+ Test(R"(
+fn a(x : i32, y : u32, z : f32) {
+}
+
+fn b() {
+ a(1i, 2u, 3.0f);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FnCall_3Args_Ret_f32) {
+ Test(R"(
+fn a(x : i32, y : u32, z : f32) -> f32 {
+ return z;
+}
+
+fn b() {
+ var v : f32 = a(1i, 2u, 3.0f);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FnCall_PtrArgs) {
+ Test(R"(
+var<private> y : i32 = 2i;
+
+fn a(px : ptr<function, i32>, py : ptr<private, i32>) -> i32 {
+ return (*(px) + *(py));
+}
+
+fn b() -> i32 {
+ var x : i32 = 1i;
+ return a(&(x), &(y));
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Builtin Call
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, BuiltinCall_Stmt) {
+ Test(R"(
+fn f() {
+ workgroupBarrier();
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BuiltinCall_Expr) {
+ Test(R"(
+fn f(a : i32, b : i32) {
+ var i : i32 = max(a, b);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, BuiltinCall_PtrArg) {
+ Test(R"(
+var<workgroup> v : bool;
+
+fn foo() -> bool {
+ return workgroupUniformLoad(&(v));
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Type Construct
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_i32) {
+ Test(R"(
+fn f(i : i32) {
+ var v : i32 = i32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_u32) {
+ Test(R"(
+fn f(i : u32) {
+ var v : u32 = u32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_f32) {
+ Test(R"(
+fn f(i : f32) {
+ var v : f32 = f32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_bool) {
+ Test(R"(
+fn f(i : bool) {
+ var v : bool = bool(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_struct) {
+ Test(R"(
+struct S {
+ a : i32,
+ b : u32,
+ c : f32,
+}
+
+fn f(a : i32, b : u32, c : f32) {
+ var v : S = S(a, b, c);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_array) {
+ Test(R"(
+fn f(i : i32) {
+ var v : array<i32, 3u> = array<i32, 3u>(i, i, i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_vec3i_Splat) {
+ Test(R"(
+fn f(i : i32) {
+ var v : vec3<i32> = vec3<i32>(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_vec3i_Scalars) {
+ Test(R"(
+fn f(i : i32) {
+ var v : vec3<i32> = vec3<i32>(i, i, i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_mat2x3f_Scalars) {
+ Test(R"(
+fn f(i : f32) {
+ var v : mat2x3<f32> = mat2x3<f32>(i, i, i, i, i, i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConstruct_mat2x3f_Columns) {
+ Test(R"(
+fn f(i : f32) {
+ var v : mat2x3<f32> = mat2x3<f32>(vec3<f32>(i, i, i), vec3<f32>(i, i, i));
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Type Convert
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, TypeConvert_i32_to_u32) {
+ Test(R"(
+fn f(i : i32) {
+ var v : u32 = u32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_u32_to_f32) {
+ Test(R"(
+fn f(i : u32) {
+ var v : f32 = f32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_f32_to_i32) {
+ Test(R"(
+fn f(i : f32) {
+ var v : i32 = i32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_bool_to_u32) {
+ Test(R"(
+fn f(i : bool) {
+ var v : u32 = u32(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_vec3i_to_vec3u) {
+ Test(R"(
+fn f(i : vec3<i32>) {
+ var v : vec3<u32> = vec3<u32>(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_vec3u_to_vec3f) {
+ Test(R"(
+fn f(i : vec3<u32>) {
+ var v : vec3<f32> = vec3<f32>(i);
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, TypeConvert_mat2x3f_to_mat2x3h) {
+ Test(R"(
+enable f16;
+
+fn f(i : mat2x3<f32>) {
+ var v : mat2x3<f16> = mat2x3<f16>(i);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Access
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, Access_Value_vec3f_1) {
+ Test(R"(
+fn f(v : vec3<f32>) -> f32 {
+ return v[1];
+}
+)",
+ R"(
+fn f(v : vec3<f32>) -> f32 {
+ return v.y;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_vec3f_1) {
+ Test(R"(
+var<private> v : vec3<f32>;
+
+fn f() -> f32 {
+ return v[1];
+}
+)",
+ R"(
+var<private> v : vec3<f32>;
+
+fn f() -> f32 {
+ return v.y;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_vec3f_z) {
+ Test(R"(
+fn f(v : vec3<f32>) -> f32 {
+ return v.z;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_vec3f_z) {
+ Test(R"(
+var<private> v : vec3<f32>;
+
+fn f() -> f32 {
+ return v.z;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_vec3f_g) {
+ Test(R"(
+fn f(v : vec3<f32>) -> f32 {
+ return v.g;
+}
+)",
+ R"(
+fn f(v : vec3<f32>) -> f32 {
+ return v.y;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_vec3f_g) {
+ Test(R"(
+var<private> v : vec3<f32>;
+
+fn f() -> f32 {
+ return v.g;
+}
+)",
+ R"(
+var<private> v : vec3<f32>;
+
+fn f() -> f32 {
+ return v.y;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_vec3f_i) {
+ Test(R"(
+fn f(v : vec3<f32>, i : i32) -> f32 {
+ return v[i];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_vec3f_i) {
+ Test(R"(
+var<private> v : vec3<f32>;
+
+fn f(i : i32) -> f32 {
+ return v[i];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_mat3x2f_1_0) {
+ Test(R"(
+fn f(m : mat3x2<f32>) -> f32 {
+ return m[1][0];
+}
+)",
+ R"(
+fn f(m : mat3x2<f32>) -> f32 {
+ return m[1i].x;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_mat3x2f_1_0) {
+ Test(R"(
+var<private> m : mat3x2<f32>;
+
+fn f() -> f32 {
+ return m[1][0];
+}
+)",
+ R"(
+var<private> m : mat3x2<f32>;
+
+fn f() -> f32 {
+ return m[1i].x;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_mat3x2f_u_0) {
+ Test(R"(
+fn f(m : mat3x2<f32>, u : u32) -> f32 {
+ return m[u][0];
+}
+)",
+ R"(
+fn f(m : mat3x2<f32>, u : u32) -> f32 {
+ return m[u].x;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_mat3x2f_u_0) {
+ Test(R"(
+var<private> m : mat3x2<f32>;
+
+fn f(u : u32) -> f32 {
+ return m[u][0];
+}
+)",
+ R"(
+var<private> m : mat3x2<f32>;
+
+fn f(u : u32) -> f32 {
+ return m[u].x;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_mat3x2f_u_i) {
+ Test(R"(
+fn f(m : mat3x2<f32>, u : u32, i : i32) -> f32 {
+ return m[u][i];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_mat3x2f_u_i) {
+ Test(R"(
+var<private> m : mat3x2<f32>;
+
+fn f(u : u32, i : i32) -> f32 {
+ return m[u][i];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_array_0u) {
+ Test(R"(
+fn f(a : array<i32, 4u>) -> i32 {
+ return a[0u];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_array_0u) {
+ Test(R"(
+var<private> a : array<i32, 4u>;
+
+fn f() -> i32 {
+ return a[0u];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Value_array_i) {
+ Test(R"(
+fn f(a : array<i32, 4u>, i : i32) -> i32 {
+ return a[i];
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Access_Ref_array_i) {
+ Test(R"(
+var<private> a : array<i32, 4u>;
+
+fn f(i : i32) -> i32 {
+ return a[i];
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// Unary ops
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramRoundtripTest, UnaryOp_Negate) {
@@ -625,6 +1081,164 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Module-scope var
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_i32) {
+ Test("var<private> v : i32 = 1i;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_u32) {
+ Test("var<private> v : u32 = 1u;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_f32) {
+ Test("var<private> v : f32 = 1.0f;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_f16) {
+ Test(R"(
+enable f16;
+
+var<private> v : f16 = 1.0h;
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_bool) {
+ Test("var<private> v : bool = true;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_array_NoArgs) {
+ Test("var<private> v : array<i32, 4u> = array<i32, 4u>();");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_array_Zero) {
+ Test("var<private> v : array<i32, 4u> = array<i32, 4u>(0i, 0i, 0i, 0i);",
+ "var<private> v : array<i32, 4u> = array<i32, 4u>();");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_array_SameValue) {
+ Test("var<private> v : array<i32, 4u> = array<i32, 4u>(4i, 4i, 4i, 4i);");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_array_DifferentValues) {
+ Test("var<private> v : array<i32, 4u> = array<i32, 4u>(1i, 2i, 3i, 4i);");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_struct_NoArgs) {
+ Test(R"(
+struct S {
+ i : i32,
+ u : u32,
+ f : f32,
+}
+
+var<private> s : S = S();
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_struct_Zero) {
+ Test(R"(
+struct S {
+ i : i32,
+ u : u32,
+ f : f32,
+}
+
+var<private> s : S = S(0i, 0u, 0f);
+)",
+ R"(
+struct S {
+ i : i32,
+ u : u32,
+ f : f32,
+}
+
+var<private> s : S = S();
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_struct_SameValue) {
+ Test(R"(
+struct S {
+ a : i32,
+ b : i32,
+ c : i32,
+}
+
+var<private> s : S = S(4i, 4i, 4i);
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_struct_DifferentValues) {
+ Test(R"(
+struct S {
+ a : i32,
+ b : i32,
+ c : i32,
+}
+
+var<private> s : S = S(1i, 2i, 3i);
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_vec3f_NoArgs) {
+ Test("var<private> v : vec3<f32> = vec3<f32>();");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_vec3f_Zero) {
+ Test("var<private> v : vec3<f32> = vec3<f32>(0f);",
+ "var<private> v : vec3<f32> = vec3<f32>();");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_vec3f_Splat) {
+ Test("var<private> v : vec3<f32> = vec3<f32>(1.0f);");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_vec3f_Scalars) {
+ Test("var<private> v : vec3<f32> = vec3<f32>(1.0f, 2.0f, 3.0f);");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_mat2x3f_NoArgs) {
+ Test("var<private> v : mat2x3<f32> = mat2x3<f32>();");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_mat2x3f_Scalars_SameValue) {
+ Test("var<private> v : mat2x3<f32> = mat2x3<f32>(4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f);",
+ "var<private> v : mat2x3<f32> = mat2x3<f32>(vec3<f32>(4.0f), vec3<f32>(4.0f));");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_mat2x3f_Scalars) {
+ Test("var<private> v : mat2x3<f32> = mat2x3<f32>(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f);",
+ "var<private> v : mat2x3<f32> = "
+ "mat2x3<f32>(vec3<f32>(1.0f, 2.0f, 3.0f), vec3<f32>(4.0f, 5.0f, 6.0f));");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_mat2x3f_Columns) {
+ Test(
+ "var<private> v : mat2x3<f32> = "
+ "mat2x3<f32>(vec3<f32>(1.0f, 2.0f, 3.0f), vec3<f32>(4.0f, 5.0f, 6.0f));");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Private_mat2x3f_Columns_SameValue) {
+ Test(
+ "var<private> v : mat2x3<f32> = "
+ "mat2x3<f32>(vec3<f32>(4.0f, 4.0f, 4.0f), vec3<f32>(4.0f, 4.0f, 4.0f));",
+ "var<private> v : mat2x3<f32> = mat2x3<f32>(vec3<f32>(4.0f), vec3<f32>(4.0f));");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_Uniform_vec4i) {
+ Test("@group(10) @binding(20) var<uniform> v : vec4<i32>;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_StorageRead_u32) {
+ Test("@group(10) @binding(20) var<storage, read> v : u32;");
+}
+
+TEST_F(IRToProgramRoundtripTest, ModuleScopeVar_StorageReadWrite_i32) {
+ Test("@group(10) @binding(20) var<storage, read_write> v : i32;");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// Function-scope var
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramRoundtripTest, FunctionScopeVar_i32) {
@@ -654,6 +1268,31 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Function-scope let
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_i32) {
+ Test(R"(
+fn f(i : i32) -> i32 {
+ let a = (42i + i);
+ let b = (24i + i);
+ let c = (a + b);
+ return c;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_ptr) {
+ Test(R"(
+fn f() -> i32 {
+ var a : array<i32, 3u>;
+ let b = &(a[1i]);
+ let c = *(b);
+ return c;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// If
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramRoundtripTest, If_CallFn) {
diff --git a/src/tint/ir/to_program_test.cc b/src/tint/ir/to_program_test.cc
index 00b5681..4d27531 100644
--- a/src/tint/ir/to_program_test.cc
+++ b/src/tint/ir/to_program_test.cc
@@ -61,15 +61,6 @@
EXPECT_WGSL("");
}
-TEST_F(IRToProgramTest, SingleFunction_Empty) {
- b.Function("f", ty.void_());
-
- EXPECT_WGSL(R"(
-fn f() {
-}
-)");
-}
-
TEST_F(IRToProgramTest, SingleFunction_Return) {
auto* fn = b.Function("f", ty.void_());
@@ -1288,6 +1279,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Add(ty.i32(), b.Load(v), 1_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1304,6 +1297,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Subtract(ty.i32(), b.Load(v), 1_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1320,6 +1315,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Add(ty.i32(), b.Load(v), 8_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1336,6 +1333,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Subtract(ty.i32(), b.Load(v), 8_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1352,6 +1351,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Multiply(ty.i32(), b.Load(v), 8_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1368,6 +1369,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Divide(ty.i32(), b.Load(v), 8_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1384,6 +1387,8 @@
b.With(fn->Block(), [&] {
auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, b.Xor(ty.i32(), b.Load(v), 8_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1443,6 +1448,8 @@
b.With(fn->Block(), [&] { //
b.Var("i", ty.ptr<function, i32>());
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1458,6 +1465,8 @@
b.With(fn->Block(), [&] {
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(42_i));
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1481,6 +1490,8 @@
auto* lb = b.Load(vb)->Result();
auto* vc = b.Var("c", ty.ptr<function, i32>());
vc->SetInitializer(lb);
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1496,7 +1507,8 @@
// If
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, If_CallFn) {
- auto* a = b.Function("a", ty.void_());
+ auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn = b.Function("f", ty.void_());
auto* cond = b.FunctionParam("cond", ty.bool_());
@@ -1505,9 +1517,11 @@
b.With(fn->Block(), [&] {
auto* if_ = b.If(cond);
b.With(if_->True(), [&] {
- b.Call(ty.void_(), a);
+ b.Call(ty.void_(), fn_a);
b.ExitIf(if_);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1530,6 +1544,8 @@
b.With(fn->Block(), [&] {
auto if_ = b.If(cond);
b.With(if_->True(), [&] { b.Return(fn); });
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1549,6 +1565,7 @@
cond->SetInitializer(b.Constant(true));
auto if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn, 42_i); });
+
b.Return(fn, 10_i);
});
@@ -1565,8 +1582,10 @@
TEST_F(IRToProgramTest, If_CallFn_Else_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn_b = b.Function("b", ty.void_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b); });
auto* fn = b.Function("f", ty.void_());
auto* cond = b.FunctionParam("cond", ty.bool_());
@@ -1582,6 +1601,8 @@
b.Call(ty.void_(), fn_b);
b.ExitIf(if_);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1610,6 +1631,8 @@
auto if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn, 1.0_f); });
b.With(if_->False(), [&] { b.Return(fn, 2.0_f); });
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1626,8 +1649,10 @@
TEST_F(IRToProgramTest, If_Return_u32_Else_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn_b = b.Function("b", ty.void_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b); });
auto* fn = b.Function("f", ty.u32());
@@ -1666,10 +1691,13 @@
TEST_F(IRToProgramTest, If_CallFn_ElseIf_CallFn) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn_b = b.Function("b", ty.void_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b); });
auto* fn_c = b.Function("c", ty.void_());
+ b.With(fn_c->Block(), [&] { b.Return(fn_c); });
auto* fn = b.Function("f", ty.void_());
@@ -1690,6 +1718,8 @@
b.ExitIf(if1);
});
b.Call(ty.void_(), fn_c);
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1738,19 +1768,23 @@
b.With(if2->True(), [&] {
b.Call(ty.void_(), x, 1_i);
b.ExitIf(if2);
- b.With(if2->False(), [&] {
- auto* if3 = b.If(pc);
- b.With(if3->True(), [&] {
- b.Call(ty.void_(), x, 2_i);
- b.ExitIf(if3);
- });
- b.With(if3->False(), [&] {
- b.Call(ty.void_(), x, 3_i);
- b.ExitIf(if3);
- });
- });
});
+ b.With(if2->False(), [&] {
+ auto* if3 = b.If(pc);
+ b.With(if3->True(), [&] {
+ b.Call(ty.void_(), x, 2_i);
+ b.ExitIf(if3);
+ });
+ b.With(if3->False(), [&] {
+ b.Call(ty.void_(), x, 3_i);
+ b.ExitIf(if3);
+ });
+ b.ExitIf(if2);
+ });
+ b.ExitIf(if1);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
fn x(i : i32) -> bool {
@@ -1776,6 +1810,7 @@
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramTest, Switch_Default) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn = b.Function("f", ty.void_());
@@ -1788,6 +1823,8 @@
b.Call(ty.void_(), fn_a);
b.ExitSwitch(s);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1807,10 +1844,13 @@
TEST_F(IRToProgramTest, Switch_3_Cases) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn_b = b.Function("b", ty.void_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b); });
auto* fn_c = b.Function("c", ty.void_());
+ b.With(fn_c->Block(), [&] { b.Return(fn_c); });
auto* fn = b.Function("f", ty.void_());
@@ -1836,6 +1876,8 @@
b.Call(ty.void_(), fn_c);
b.ExitSwitch(s);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -1867,6 +1909,7 @@
TEST_F(IRToProgramTest, Switch_3_Cases_AllReturn) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
auto* fn = b.Function("f", ty.void_());
@@ -1912,13 +1955,15 @@
TEST_F(IRToProgramTest, Switch_Nested) {
auto* fn_a = b.Function("a", ty.void_());
+ b.With(fn_a->Block(), [&] { b.Return(fn_a); });
- b.Function("b", ty.void_());
+ auto* fn_b = b.Function("b", ty.void_());
+ b.With(fn_b->Block(), [&] { b.Return(fn_b); });
auto* fn_c = b.Function("c", ty.void_());
+ b.With(fn_c->Block(), [&] { b.Return(fn_c); });
auto* fn = b.Function("f", ty.void_());
-
b.With(fn->Block(), [&] {
auto* v1 = b.Var("v1", ty.ptr<function, i32>());
v1->SetInitializer(b.Constant(42_i));
@@ -1946,11 +1991,15 @@
Switch::CaseSelector{},
}),
[&] { b.Return(fn); });
+
+ b.ExitSwitch(s1);
});
b.With(b.Case(s1, {Switch::CaseSelector{b.Constant(2_i)}}), [&] {
b.Call(ty.void_(), fn_c);
b.ExitSwitch(s1);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
fn a() {
@@ -1998,15 +2047,22 @@
b.With(loop->Initializer(), [&] {
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
});
- b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i));
+ b.NextIteration(loop);
+ });
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2030,9 +2086,15 @@
auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
});
- b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i));
+ b.NextIteration(loop);
+ });
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2053,13 +2115,17 @@
b.With(loop->Initializer(), [&] {
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
});
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2084,6 +2150,7 @@
b.With(loop->Initializer(), [&] {
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
auto* if1 = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
@@ -2093,9 +2160,13 @@
auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
b.With(if2->True(), [&] { b.Return(fn, 1_i); });
b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+ b.Unreachable();
});
- b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i));
+ b.NextIteration(loop);
+ });
});
b.Return(fn, 3_i);
@@ -2141,9 +2212,14 @@
auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
b.With(if2->True(), [&] { b.Return(fn, 1_i); });
b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+
+ b.Continue(loop);
});
- b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i));
+ b.NextIteration(loop);
+ });
b.Return(fn, 3_i);
});
@@ -2181,6 +2257,7 @@
b.With(loop->Initializer(), [&] {
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(b.Constant(0_i));
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
auto* if1 = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
@@ -2190,8 +2267,11 @@
auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
b.With(if2->True(), [&] { b.Return(fn, 1_i); });
b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+
+ b.NextIteration(loop);
});
});
+
b.Return(fn, 3_i);
});
@@ -2228,6 +2308,7 @@
auto* n_0 = b.Call(ty.i32(), fn_n, 0_i)->Result();
auto* i = b.Var("i", ty.ptr<function, i32>());
i->SetInitializer(n_0);
+ b.NextIteration(loop);
b.With(loop->Body(), [&] {
auto* load = b.Load(i);
@@ -2235,10 +2316,17 @@
auto* if_ = b.If(b.LessThan(ty.bool_(), load, call));
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
+
+ b.Continue(loop);
});
- b.With(loop->Continuing(), [&] { b.Store(i, b.Call(ty.i32(), fn_n, b.Load(i))); });
+ b.With(loop->Continuing(), [&] {
+ b.Store(i, b.Call(ty.i32(), fn_n, b.Load(i)));
+ b.NextIteration(loop);
+ });
});
+
+ b.Return(fn_f);
});
EXPECT_WGSL(R"(
@@ -2266,7 +2354,11 @@
auto* cond = b.If(true);
b.With(cond->True(), [&] { b.ExitIf(cond); });
b.With(cond->False(), [&] { b.ExitLoop(loop); });
+
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2289,7 +2381,11 @@
auto* if_ = b.If(cond);
b.With(if_->True(), [&] { b.ExitIf(if_); });
b.With(if_->False(), [&] { b.ExitLoop(loop); });
+
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2310,8 +2406,11 @@
auto* cond = b.If(true);
b.With(cond->True(), [&] { b.ExitIf(cond); });
b.With(cond->False(), [&] { b.ExitLoop(loop); });
+
b.ExitLoop(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2338,7 +2437,11 @@
auto* if2 = b.If(cond);
b.With(if2->True(), [&] { b.ExitLoop(loop); });
+
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2367,7 +2470,11 @@
auto* if2 = b.If(cond);
b.With(if2->True(), [&] { b.Return(fn); });
+
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2391,6 +2498,8 @@
auto* loop = b.Loop();
b.With(loop->Body(), [&] { b.ExitLoop(loop); });
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2413,7 +2522,10 @@
b.With(loop->Body(), [&] {
auto* if_ = b.If(cond);
b.With(if_->True(), [&] { b.ExitLoop(loop); });
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
fn f(cond : bool) {
@@ -2437,7 +2549,10 @@
b.With(loop->Body(), [&] {
auto* if_ = b.If(cond);
b.With(if_->True(), [&] { b.Return(fn); });
+ b.Continue(loop);
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2461,11 +2576,16 @@
auto* loop = b.Loop();
b.With(loop->Body(), [&] {
- auto* if_ = b.If(cond);
+ auto* if_ = b.If(b.Load(cond));
b.With(if_->True(), [&] { b.Return(fn); });
+ b.Continue(loop);
+ });
+ b.With(loop->Continuing(), [&] {
+ b.Store(cond, true);
+ b.NextIteration(loop);
});
- b.With(loop->Continuing(), [&] { b.Store(cond, true); });
+ b.Return(fn);
});
EXPECT_WGSL(R"(
@@ -2485,9 +2605,9 @@
}
TEST_F(IRToProgramTest, Loop_VarsDeclaredOutsideAndInside) {
- auto* f = b.Function("f", ty.void_());
+ auto* fn = b.Function("f", ty.void_());
- b.With(f->Block(), [&] {
+ b.With(fn->Block(), [&] {
auto* var_b = b.Var("b", ty.ptr<function, i32>());
var_b->SetInitializer(b.Constant(1_i));
@@ -2500,15 +2620,19 @@
auto* body_load_a = b.Load(var_a);
auto* body_load_b = b.Load(var_b);
auto* if_ = b.If(b.Equal(ty.bool_(), body_load_a, body_load_b));
- b.With(if_->True(), [&] { b.Return(f); });
+ b.With(if_->True(), [&] { b.Return(fn); });
b.With(if_->False(), [&] { b.ExitIf(if_); });
+ b.Continue(loop);
b.With(loop->Continuing(), [&] {
auto* cont_load_a = b.Load(var_a);
auto* cont_load_b = b.Load(var_b);
b.Store(var_b, b.Add(ty.i32(), cont_load_a, cont_load_b));
+ b.NextIteration(loop);
});
});
+
+ b.Return(fn);
});
EXPECT_WGSL(R"(
diff --git a/src/tint/ir/transform/merge_return_test.cc b/src/tint/ir/transform/merge_return_test.cc
index 0991bb5..2e59642 100644
--- a/src/tint/ir/transform/merge_return_test.cc
+++ b/src/tint/ir/transform/merge_return_test.cc
@@ -100,7 +100,8 @@
auto* swtch = b.Switch(in);
b.With(b.Case(swtch, {Switch::CaseSelector{}}), [&] { b.ExitSwitch(swtch); });
- b.Loop();
+ auto* l = b.Loop();
+ b.With(l->Body(), [&] { b.ExitLoop(l); });
auto* ifelse = b.If(cond);
ifelse->SetResults(b.InstructionResult(ty.i32()));
@@ -118,14 +119,17 @@
exit_switch # switch_1
}
}
- loop [] { # loop_1
+ loop [b: %b3] { # loop_1
+ %b3 = block { # body
+ exit_loop # loop_1
+ }
}
- %3:i32 = if %4 [t: %b3, f: %b4] { # if_1
- %b3 = block { # true
+ %3:i32 = if %4 [t: %b4, f: %b5] { # if_1
+ %b4 = block { # true
%5:i32 = add %2, 1i
exit_if %5 # if_1
}
- %b4 = block { # false
+ %b5 = block { # false
%6:i32 = add %2, 2i
exit_if %6 # if_1
}
@@ -1506,7 +1510,8 @@
b.With(func->Block(), [&] {
auto* sw = b.Switch(cond);
b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}), [&] {
- auto* ifelse = b.If(cond);
+ auto* ifcond = b.Equal(ty.bool_(), cond, 1_i);
+ auto* ifelse = b.If(ifcond);
b.With(ifelse->True(), [&] { b.Return(func, 42_i); });
b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
@@ -1528,7 +1533,8 @@
%b2 = block {
switch %3 [c: (1i, %b3), c: (default, %b4)] { # switch_1
%b3 = block { # case
- if %3 [t: %b5, f: %b6] { # if_1
+ %4:bool = eq %3, 1i
+ if %4 [t: %b5, f: %b6] { # if_1
%b5 = block { # true
ret 42i
}
@@ -1560,7 +1566,8 @@
%continue_execution:ptr<function, bool, read_write> = var, true
switch %3 [c: (1i, %b3), c: (default, %b4)] { # switch_1
%b3 = block { # case
- if %3 [t: %b5, f: %b6] { # if_1
+ %6:bool = eq %3, 1i
+ if %6 [t: %b5, f: %b6] { # if_1
%b5 = block { # true
store %continue_execution, false
store %return_value, 42i
@@ -1570,8 +1577,8 @@
exit_if # if_1
}
}
- %6:bool = load %continue_execution
- if %6 [t: %b7] { # if_2
+ %7:bool = load %continue_execution
+ if %7 [t: %b7] { # if_2
%b7 = block { # true
store %1, 2i
exit_switch # switch_1
@@ -1583,15 +1590,15 @@
exit_switch # switch_1
}
}
- %7:bool = load %continue_execution
- if %7 [t: %b8] { # if_3
+ %8:bool = load %continue_execution
+ if %8 [t: %b8] { # if_3
%b8 = block { # true
store %return_value, 0i
exit_if # if_3
}
}
- %8:i32 = load %return_value
- ret %8
+ %9:i32 = load %return_value
+ ret %9
}
}
)";
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index a743786..e43a208 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -35,6 +35,7 @@
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
@@ -266,9 +267,9 @@
[&](Call* c) { CheckCall(c); }, //
[&](If* if_) { CheckIf(if_); }, //
[&](Load*) {}, //
- [&](Loop*) {}, //
+ [&](Loop* l) { CheckLoop(l); }, //
[&](Store*) {}, //
- [&](Switch*) {}, //
+ [&](Switch* s) { CheckSwitch(s); }, //
[&](Swizzle*) {}, //
[&](Terminator* b) { CheckTerminator(b); }, //
[&](Unary* u) { CheckUnary(u); }, //
@@ -405,6 +406,28 @@
if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
AddError(if_, If::kConditionOperandOffset, "if: condition must be a `bool` type");
}
+
+ CheckBlock(if_->True());
+ if (!if_->False()->IsEmpty()) {
+ CheckBlock(if_->False());
+ }
+ }
+
+ void CheckLoop(Loop* l) {
+ if (!l->Initializer()->IsEmpty()) {
+ CheckBlock(l->Initializer());
+ }
+ CheckBlock(l->Body());
+
+ if (!l->Continuing()->IsEmpty()) {
+ CheckBlock(l->Continuing());
+ }
+ }
+
+ void CheckSwitch(Switch* s) {
+ for (auto& cse : s->Cases()) {
+ CheckBlock(cse.block);
+ }
}
void CheckVar(Var* var) {
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 723c3fc..db8079f 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -515,6 +515,50 @@
)");
}
+TEST_F(IR_ValidateTest, If_EmptyFalse) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto* if_ = b.If(true);
+ if_->True()->Append(b.Return(f));
+
+ f->Block()->Append(if_);
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure().str();
+}
+
+TEST_F(IR_ValidateTest, If_EmptyTrue) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto* if_ = b.If(true);
+ if_->False()->Append(b.Return(f));
+
+ f->Block()->Append(if_);
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:7 error: block: does not end in a terminator instruction
+ %b2 = block { # true
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ if true [t: %b2, f: %b3] { # if_1
+ %b2 = block { # true
+ }
+ %b3 = block { # false
+ ret
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidateTest, If_ConditionIsBool) {
auto* f = b.Function("my_func", ty.void_());
@@ -589,6 +633,46 @@
)");
}
+TEST_F(IR_ValidateTest, Loop_OnlyBody) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto* l = b.Loop();
+ l->Body()->Append(b.ExitLoop(l));
+
+ auto sb = b.With(f->Block());
+ sb.Append(l);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure().str();
+}
+
+TEST_F(IR_ValidateTest, Loop_EmptyBody) {
+ auto* f = b.Function("my_func", ty.void_());
+
+ auto sb = b.With(f->Block());
+ sb.Append(b.Loop());
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:4:7 error: block: does not end in a terminator instruction
+ %b2 = block { # body
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ loop [b: %b2] { # loop_1
+ %b2 = block { # body
+ }
+ }
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidateTest, Var_RootBlock_NullResult) {
auto* v = mod.instructions.Create<ir::Var>(nullptr);
b.RootBlock()->Append(v);
diff --git a/src/tint/writer/spirv/binary_writer.cc b/src/tint/writer/spirv/binary_writer.cc
index 4952bdb..cd563c3 100644
--- a/src/tint/writer/spirv/binary_writer.cc
+++ b/src/tint/writer/spirv/binary_writer.cc
@@ -37,10 +37,10 @@
process_instruction(inst);
}
-void BinaryWriter::WriteHeader(uint32_t bound) {
+void BinaryWriter::WriteHeader(uint32_t bound, uint32_t version) {
out_.push_back(spv::MagicNumber);
out_.push_back(0x00010300); // Version 1.3
- out_.push_back(kGeneratorId);
+ out_.push_back(kGeneratorId | version);
out_.push_back(bound);
out_.push_back(0);
}
diff --git a/src/tint/writer/spirv/binary_writer.h b/src/tint/writer/spirv/binary_writer.h
index b9696b8..d9d6770 100644
--- a/src/tint/writer/spirv/binary_writer.h
+++ b/src/tint/writer/spirv/binary_writer.h
@@ -30,7 +30,8 @@
/// Writes the SPIR-V header.
/// @param bound the bound to output
- void WriteHeader(uint32_t bound);
+ /// @param version the generator version number
+ void WriteHeader(uint32_t bound, uint32_t version = 0);
/// Writes the given module data into a binary. Note, this does not emit the SPIR-V header. You
/// **must** call WriteHeader() before WriteModule() if you want the SPIR-V to be emitted.
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 1fd5db9..61613da 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -53,12 +53,19 @@
#include "src/tint/transform/manager.h"
#include "src/tint/type/array.h"
#include "src/tint/type/bool.h"
+#include "src/tint/type/depth_multisampled_texture.h"
+#include "src/tint/type/depth_texture.h"
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
#include "src/tint/type/matrix.h"
+#include "src/tint/type/multisampled_texture.h"
#include "src/tint/type/pointer.h"
+#include "src/tint/type/sampled_texture.h"
+#include "src/tint/type/sampler.h"
+#include "src/tint/type/storage_texture.h"
#include "src/tint/type/struct.h"
+#include "src/tint/type/texture.h"
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
@@ -71,6 +78,10 @@
namespace {
+using namespace tint::number_suffixes; // NOLINT
+
+constexpr uint32_t kGeneratorVersion = 1;
+
void Sanitize(ir::Module* module) {
transform::Manager manager;
transform::DataMap data;
@@ -89,6 +100,8 @@
SpvStorageClass StorageClass(builtin::AddressSpace addrspace) {
switch (addrspace) {
+ case builtin::AddressSpace::kHandle:
+ return SpvStorageClassUniformConstant;
case builtin::AddressSpace::kFunction:
return SpvStorageClassFunction;
case builtin::AddressSpace::kIn:
@@ -146,7 +159,7 @@
}
// Serialize the module into binary SPIR-V.
- writer_.WriteHeader(module_.IdBound());
+ writer_.WriteHeader(module_.IdBound(), kGeneratorVersion);
writer_.WriteModule(&module_);
return true;
@@ -194,7 +207,14 @@
}
uint32_t GeneratorImplIr::Constant(ir::Constant* constant) {
- return Constant(constant->Value());
+ auto id = Constant(constant->Value());
+
+ // Set the name for the SPIR-V result ID if provided in the module.
+ if (auto name = ir_->NameOf(constant)) {
+ module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
+ }
+
+ return id;
}
uint32_t GeneratorImplIr::Constant(const constant::Value* constant) {
@@ -267,6 +287,14 @@
});
}
+uint32_t GeneratorImplIr::Undef(const type::Type* type) {
+ return undef_values_.GetOrCreate(type, [&] {
+ auto id = module_.NextId();
+ module_.PushType(spv::Op::OpUndef, {Type(type), id});
+ return id;
+ });
+}
+
uint32_t GeneratorImplIr::Type(const type::Type* ty,
builtin::AddressSpace addrspace /* = kUndefined */) {
return types_.GetOrCreate(ty, [&] {
@@ -316,6 +344,18 @@
Type(ptr->StoreType(), ptr->AddressSpace())});
},
[&](const type::Struct* str) { EmitStructType(id, str, addrspace); },
+ [&](const type::Texture* tex) { EmitTextureType(id, tex); },
+ [&](const type::Sampler* s) {
+ module_.PushType(spv::Op::OpTypeSampler, {id});
+
+ // Register both of the sampler types, as they're the same in SPIR-V.
+ if (s->kind() == type::SamplerKind::kSampler) {
+ types_.Add(
+ ir_->Types().Get<type::Sampler>(type::SamplerKind::kComparisonSampler), id);
+ } else {
+ types_.Add(ir_->Types().Get<type::Sampler>(type::SamplerKind::kSampler), id);
+ }
+ },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled type: " << ty->FriendlyName();
});
@@ -436,6 +476,82 @@
}
}
+void GeneratorImplIr::EmitTextureType(uint32_t id, const type::Texture* texture) {
+ uint32_t sampled_type = Switch(
+ texture, //
+ [&](const type::DepthTexture*) { return Type(ir_->Types().f32()); },
+ [&](const type::DepthMultisampledTexture*) { return Type(ir_->Types().f32()); },
+ [&](const type::SampledTexture* t) { return Type(t->type()); },
+ [&](const type::MultisampledTexture* t) { return Type(t->type()); },
+ [&](const type::StorageTexture* t) { return Type(t->type()); });
+
+ uint32_t dim = SpvDimMax;
+ uint32_t array = 0u;
+ switch (texture->dim()) {
+ case type::TextureDimension::kNone: {
+ break;
+ }
+ case type::TextureDimension::k1d: {
+ dim = SpvDim1D;
+ if (texture->Is<type::SampledTexture>()) {
+ module_.PushCapability(SpvCapabilitySampled1D);
+ } else if (texture->Is<type::StorageTexture>()) {
+ module_.PushCapability(SpvCapabilityImage1D);
+ }
+ break;
+ }
+ case type::TextureDimension::k2d: {
+ dim = SpvDim2D;
+ break;
+ }
+ case type::TextureDimension::k2dArray: {
+ dim = SpvDim2D;
+ array = 1u;
+ break;
+ }
+ case type::TextureDimension::k3d: {
+ dim = SpvDim3D;
+ break;
+ }
+ case type::TextureDimension::kCube: {
+ dim = SpvDimCube;
+ break;
+ }
+ case type::TextureDimension::kCubeArray: {
+ dim = SpvDimCube;
+ array = 1u;
+ if (texture->IsAnyOf<type::SampledTexture, type::DepthTexture>()) {
+ module_.PushCapability(SpvCapabilitySampledCubeArray);
+ }
+ break;
+ }
+ }
+
+ // The Vulkan spec says: The "Depth" operand of OpTypeImage is ignored.
+ // In SPIRV, 0 means not depth, 1 means depth, and 2 means unknown.
+ // Using anything other than 0 is problematic on various Vulkan drivers.
+ uint32_t depth = 0u;
+
+ uint32_t ms = 0u;
+ if (texture->IsAnyOf<type::MultisampledTexture, type::DepthMultisampledTexture>()) {
+ ms = 1u;
+ }
+
+ uint32_t sampled = 2u;
+ if (texture->IsAnyOf<type::MultisampledTexture, type::SampledTexture, type::DepthTexture,
+ type::DepthMultisampledTexture>()) {
+ sampled = 1u;
+ }
+
+ uint32_t format = SpvImageFormat_::SpvImageFormatUnknown;
+ if (auto* st = texture->As<type::StorageTexture>()) {
+ format = TexelFormat(st->texel_format());
+ }
+
+ module_.PushType(spv::Op::OpTypeImage,
+ {id, sampled_type, dim, depth, array, ms, sampled, format});
+}
+
void GeneratorImplIr::EmitFunction(ir::Function* func) {
auto id = Value(func);
@@ -622,13 +738,17 @@
inst, //
[&](ir::Access* a) { EmitAccess(a); }, //
[&](ir::Binary* b) { EmitBinary(b); }, //
+ [&](ir::Bitcast* b) { EmitBitcast(b); }, //
[&](ir::BuiltinCall* b) { EmitBuiltinCall(b); }, //
[&](ir::Construct* c) { EmitConstruct(c); }, //
+ [&](ir::Convert* c) { EmitConvert(c); }, //
[&](ir::Load* l) { EmitLoad(l); }, //
[&](ir::Loop* l) { EmitLoop(l); }, //
[&](ir::Switch* sw) { EmitSwitch(sw); }, //
+ [&](ir::Swizzle* s) { EmitSwizzle(s); }, //
[&](ir::Store* s) { EmitStore(s); }, //
[&](ir::UserCall* c) { EmitUserCall(c); }, //
+ [&](ir::Unary* u) { EmitUnary(u); }, //
[&](ir::Var* v) { EmitVar(v); }, //
[&](ir::If* i) { EmitIf(i); }, //
[&](ir::Terminator* t) { EmitTerminator(t); }, //
@@ -670,7 +790,7 @@
{
Value(breakif->Condition()),
loop_merge_label_,
- Label(breakif->Loop()->Body()),
+ loop_header_label_,
});
},
[&](ir::Continue* cont) {
@@ -681,8 +801,8 @@
[&](ir::ExitSwitch*) {
current_function_.push_inst(spv::Op::OpBranch, {switch_merge_label_});
},
- [&](ir::NextIteration* loop) {
- current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Body())});
+ [&](ir::NextIteration*) {
+ current_function_.push_inst(spv::Op::OpBranch, {loop_header_label_});
},
[&](ir::Unreachable*) { current_function_.push_inst(spv::Op::OpUnreachable, {}); },
@@ -784,8 +904,11 @@
void GeneratorImplIr::EmitBinary(ir::Binary* binary) {
auto id = Value(binary);
+ auto lhs = Value(binary->LHS());
+ auto rhs = Value(binary->RHS());
auto* ty = binary->Result()->Type();
auto* lhs_ty = binary->LHS()->Type();
+ auto* rhs_ty = binary->RHS()->Type();
// Determine the opcode.
spv::Op op = spv::Op::Max;
@@ -794,10 +917,66 @@
op = ty->is_integer_scalar_or_vector() ? spv::Op::OpIAdd : spv::Op::OpFAdd;
break;
}
+ case ir::Binary::Kind::kDivide: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSDiv;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUDiv;
+ } else if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFDiv;
+ }
+ break;
+ }
+ case ir::Binary::Kind::kMultiply: {
+ if (ty->is_integer_scalar_or_vector()) {
+ // If the result is an integer then we can only use OpIMul.
+ op = spv::Op::OpIMul;
+ } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_scalar()) {
+ // Two float scalars multiply with OpFMul.
+ op = spv::Op::OpFMul;
+ } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_vector()) {
+ // Two float vectors multiply with OpFMul.
+ op = spv::Op::OpFMul;
+ } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_vector()) {
+ // Use OpVectorTimesScalar for scalar * vector, and swap the operand order.
+ std::swap(lhs, rhs);
+ op = spv::Op::OpVectorTimesScalar;
+ } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_scalar()) {
+ // Use OpVectorTimesScalar for scalar * vector.
+ op = spv::Op::OpVectorTimesScalar;
+ } else if (lhs_ty->is_float_scalar() && rhs_ty->is_float_matrix()) {
+ // Use OpMatrixTimesScalar for scalar * matrix, and swap the operand order.
+ std::swap(lhs, rhs);
+ op = spv::Op::OpMatrixTimesScalar;
+ } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_scalar()) {
+ // Use OpMatrixTimesScalar for scalar * matrix.
+ op = spv::Op::OpMatrixTimesScalar;
+ } else if (lhs_ty->is_float_vector() && rhs_ty->is_float_matrix()) {
+ // Use OpVectorTimesMatrix for vector * matrix.
+ op = spv::Op::OpVectorTimesMatrix;
+ } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_vector()) {
+ // Use OpMatrixTimesVector for matrix * vector.
+ op = spv::Op::OpMatrixTimesVector;
+ } else if (lhs_ty->is_float_matrix() && rhs_ty->is_float_matrix()) {
+ // Use OpMatrixTimesMatrix for matrix * vector.
+ op = spv::Op::OpMatrixTimesMatrix;
+ }
+ break;
+ }
case ir::Binary::Kind::kSubtract: {
op = ty->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub;
break;
}
+ case ir::Binary::Kind::kModulo: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSRem;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpUMod;
+ } else if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFRem;
+ }
+ break;
+ }
case ir::Binary::Kind::kAnd: {
op = spv::Op::OpBitwiseAnd;
@@ -812,6 +991,19 @@
break;
}
+ case ir::Binary::Kind::kShiftLeft: {
+ op = spv::Op::OpShiftLeftLogical;
+ break;
+ }
+ case ir::Binary::Kind::kShiftRight: {
+ if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpShiftRightArithmetic;
+ } else if (ty->is_unsigned_integer_scalar_or_vector()) {
+ op = spv::Op::OpShiftRightLogical;
+ }
+ break;
+ }
+
case ir::Binary::Kind::kEqual: {
if (lhs_ty->is_bool_scalar_or_vector()) {
op = spv::Op::OpLogicalEqual;
@@ -872,15 +1064,20 @@
}
break;
}
-
- default: {
- TINT_ICE(Writer, diagnostics_)
- << "unimplemented binary instruction: " << static_cast<uint32_t>(binary->Kind());
- }
}
// Emit the instruction.
- current_function_.push_inst(op, {Type(ty), id, Value(binary->LHS()), Value(binary->RHS())});
+ current_function_.push_inst(op, {Type(ty), id, lhs, rhs});
+}
+
+void GeneratorImplIr::EmitBitcast(ir::Bitcast* bitcast) {
+ auto* ty = bitcast->Result()->Type();
+ if (ty == bitcast->Val()->Type()) {
+ values_.Add(bitcast->Result(), Value(bitcast->Val()));
+ return;
+ }
+ current_function_.push_inst(spv::Op::OpBitcast,
+ {Type(ty), Value(bitcast), Value(bitcast->Val())});
}
void GeneratorImplIr::EmitBuiltinCall(ir::BuiltinCall* builtin) {
@@ -892,6 +1089,12 @@
values_.Add(builtin->Result(), Value(builtin->Args()[0]));
return;
}
+ if (builtin->Func() == builtin::Function::kAny &&
+ builtin->Args()[0]->Type()->Is<type::Bool>()) {
+ // any() is a passthrough for a scalar argument.
+ values_.Add(builtin->Result(), Value(builtin->Args()[0]));
+ return;
+ }
auto id = Value(builtin);
@@ -920,6 +1123,78 @@
glsl_ext_inst(GLSLstd450SAbs);
}
break;
+ case builtin::Function::kAny:
+ op = spv::Op::OpAny;
+ break;
+ case builtin::Function::kAcos:
+ glsl_ext_inst(GLSLstd450Acos);
+ break;
+ case builtin::Function::kAcosh:
+ glsl_ext_inst(GLSLstd450Acosh);
+ break;
+ case builtin::Function::kAsin:
+ glsl_ext_inst(GLSLstd450Asin);
+ break;
+ case builtin::Function::kAsinh:
+ glsl_ext_inst(GLSLstd450Asinh);
+ break;
+ case builtin::Function::kAtan:
+ glsl_ext_inst(GLSLstd450Atan);
+ break;
+ case builtin::Function::kAtan2:
+ glsl_ext_inst(GLSLstd450Atan2);
+ break;
+ case builtin::Function::kAtanh:
+ glsl_ext_inst(GLSLstd450Atanh);
+ break;
+ case builtin::Function::kClamp:
+ if (result_ty->is_float_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450NClamp);
+ } else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450UClamp);
+ } else if (result_ty->is_signed_integer_scalar_or_vector()) {
+ glsl_ext_inst(GLSLstd450SClamp);
+ }
+ break;
+ case builtin::Function::kCos:
+ glsl_ext_inst(GLSLstd450Cos);
+ break;
+ case builtin::Function::kCosh:
+ glsl_ext_inst(GLSLstd450Cosh);
+ break;
+ case builtin::Function::kCross:
+ glsl_ext_inst(GLSLstd450Cross);
+ break;
+ case builtin::Function::kDistance:
+ glsl_ext_inst(GLSLstd450Distance);
+ break;
+ case builtin::Function::kDpdx:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdx;
+ break;
+ case builtin::Function::kDpdxCoarse:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdxCoarse;
+ break;
+ case builtin::Function::kDpdxFine:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdxFine;
+ break;
+ case builtin::Function::kDpdy:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdy;
+ break;
+ case builtin::Function::kDpdyCoarse:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdyCoarse;
+ break;
+ case builtin::Function::kDpdyFine:
+ module_.PushCapability(SpvCapabilityDerivativeControl);
+ op = spv::Op::OpDPdyFine;
+ break;
+ case builtin::Function::kLength:
+ glsl_ext_inst(GLSLstd450Length);
+ break;
case builtin::Function::kMax:
if (result_ty->is_float_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450FMax);
@@ -938,6 +1213,21 @@
glsl_ext_inst(GLSLstd450UMin);
}
break;
+ case builtin::Function::kNormalize:
+ glsl_ext_inst(GLSLstd450Normalize);
+ break;
+ case builtin::Function::kSin:
+ glsl_ext_inst(GLSLstd450Sin);
+ break;
+ case builtin::Function::kSinh:
+ glsl_ext_inst(GLSLstd450Sinh);
+ break;
+ case builtin::Function::kTan:
+ glsl_ext_inst(GLSLstd450Tan);
+ break;
+ case builtin::Function::kTanh:
+ glsl_ext_inst(GLSLstd450Tanh);
+ break;
default:
TINT_ICE(Writer, diagnostics_) << "unimplemented builtin function: " << builtin->Func();
}
@@ -960,6 +1250,88 @@
current_function_.push_inst(spv::Op::OpCompositeConstruct, std::move(operands));
}
+void GeneratorImplIr::EmitConvert(ir::Convert* convert) {
+ auto* res_ty = convert->Result()->Type();
+ auto* arg_ty = convert->Args()[0]->Type();
+
+ OperandList operands = {Type(convert->Result()->Type()), Value(convert)};
+ for (auto* arg : convert->Args()) {
+ operands.push_back(Value(arg));
+ }
+
+ spv::Op op = spv::Op::Max;
+ if (res_ty->is_signed_integer_scalar_or_vector() && arg_ty->is_float_scalar_or_vector()) {
+ // float to signed int.
+ op = spv::Op::OpConvertFToS;
+ } else if (res_ty->is_unsigned_integer_scalar_or_vector() &&
+ arg_ty->is_float_scalar_or_vector()) {
+ // float to unsigned int.
+ op = spv::Op::OpConvertFToU;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_signed_integer_scalar_or_vector()) {
+ // signed int to float.
+ op = spv::Op::OpConvertSToF;
+ } else if (res_ty->is_float_scalar_or_vector() &&
+ arg_ty->is_unsigned_integer_scalar_or_vector()) {
+ // unsigned int to float.
+ op = spv::Op::OpConvertUToF;
+ } else if (res_ty->is_float_scalar_or_vector() && arg_ty->is_float_scalar_or_vector() &&
+ res_ty->Size() != arg_ty->Size()) {
+ // float to float (different bitwidth).
+ op = spv::Op::OpFConvert;
+ } else if (res_ty->is_integer_scalar_or_vector() && arg_ty->is_integer_scalar_or_vector() &&
+ res_ty->Size() == arg_ty->Size()) {
+ // int to int (same bitwidth, different signedness).
+ op = spv::Op::OpBitcast;
+ } else if (res_ty->is_bool_scalar_or_vector()) {
+ if (arg_ty->is_integer_scalar_or_vector()) {
+ // int to bool.
+ op = spv::Op::OpINotEqual;
+ } else {
+ // float to bool.
+ op = spv::Op::OpFUnordNotEqual;
+ }
+ operands.push_back(ConstantNull(arg_ty));
+ } else if (arg_ty->is_bool_scalar_or_vector()) {
+ // Select between constant one and zero, splatting them to vectors if necessary.
+ const constant::Value* one = nullptr;
+ const constant::Value* zero = nullptr;
+ Switch(
+ res_ty->DeepestElement(), //
+ [&](const type::F32*) {
+ one = ir_->constant_values.Get(1_f);
+ zero = ir_->constant_values.Get(0_f);
+ },
+ [&](const type::F16*) {
+ one = ir_->constant_values.Get(1_h);
+ zero = ir_->constant_values.Get(0_h);
+ },
+ [&](const type::I32*) {
+ one = ir_->constant_values.Get(1_i);
+ zero = ir_->constant_values.Get(0_i);
+ },
+ [&](const type::U32*) {
+ one = ir_->constant_values.Get(1_u);
+ zero = ir_->constant_values.Get(0_u);
+ });
+ TINT_ASSERT_OR_RETURN(Writer, one && zero);
+
+ if (auto* vec = res_ty->As<type::Vector>()) {
+ // Splat the scalars into vectors.
+ one = ir_->constant_values.Splat(vec, one, vec->Width());
+ zero = ir_->constant_values.Splat(vec, zero, vec->Width());
+ }
+
+ op = spv::Op::OpSelect;
+ operands.push_back(Constant(one));
+ operands.push_back(Constant(zero));
+ } else {
+ TINT_ICE(Writer, diagnostics_) << "unhandled convert instruction";
+ }
+
+ current_function_.push_inst(op, std::move(operands));
+}
+
void GeneratorImplIr::EmitLoad(ir::Load* load) {
current_function_.push_inst(spv::Op::OpLoad,
{Type(load->Result()->Type()), Value(load), Value(load->From())});
@@ -967,11 +1339,13 @@
void GeneratorImplIr::EmitLoop(ir::Loop* loop) {
auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
- auto header_label = Label(loop->Body()); // Back-edge needs to branch to the loop header
- auto body_label = module_.NextId();
+ auto body_label = Label(loop->Body());
auto continuing_label = Label(loop->Continuing());
- uint32_t merge_label = module_.NextId();
+ auto header_label = module_.NextId();
+ TINT_SCOPED_ASSIGNMENT(loop_header_label_, header_label);
+
+ auto merge_label = module_.NextId();
TINT_SCOPED_ASSIGNMENT(loop_merge_label_, merge_label);
if (init_label != 0) {
@@ -1056,10 +1430,39 @@
EmitExitPhis(swtch);
}
+void GeneratorImplIr::EmitSwizzle(ir::Swizzle* swizzle) {
+ auto id = Value(swizzle);
+ auto obj = Value(swizzle->Object());
+ OperandList operands = {Type(swizzle->Result()->Type()), id, obj, obj};
+ for (auto idx : swizzle->Indices()) {
+ operands.push_back(idx);
+ }
+ current_function_.push_inst(spv::Op::OpVectorShuffle, operands);
+}
+
void GeneratorImplIr::EmitStore(ir::Store* store) {
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
}
+void GeneratorImplIr::EmitUnary(ir::Unary* unary) {
+ auto id = Value(unary);
+ auto* ty = unary->Result()->Type();
+ spv::Op op = spv::Op::Max;
+ switch (unary->Kind()) {
+ case ir::Unary::Kind::kComplement:
+ op = spv::Op::OpNot;
+ break;
+ case ir::Unary::Kind::kNegation:
+ if (ty->is_float_scalar_or_vector()) {
+ op = spv::Op::OpFNegate;
+ } else if (ty->is_signed_integer_scalar_or_vector()) {
+ op = spv::Op::OpSNegate;
+ }
+ break;
+ }
+ current_function_.push_inst(op, {Type(ty), id, Value(unary->Val())});
+}
+
void GeneratorImplIr::EmitUserCall(ir::UserCall* call) {
auto id = Value(call);
OperandList operands = {Type(call->Result()->Type()), id, Value(call->Func())};
@@ -1103,6 +1506,7 @@
module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassOutput)});
break;
}
+ case builtin::AddressSpace::kHandle:
case builtin::AddressSpace::kStorage:
case builtin::AddressSpace::kUniform: {
TINT_ASSERT(Writer, !current_function_);
@@ -1159,11 +1563,62 @@
OperandList ops{Type(ty), Value(result)};
for (auto& branch : branches) {
- ops.push_back(Value(branch.value));
+ if (branch.value == nullptr) {
+ ops.push_back(Undef(ty));
+ } else {
+ ops.push_back(Value(branch.value));
+ }
ops.push_back(branch.label);
}
current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
}
}
+uint32_t GeneratorImplIr::TexelFormat(const builtin::TexelFormat format) {
+ switch (format) {
+ case builtin::TexelFormat::kBgra8Unorm:
+ TINT_ICE(Writer, diagnostics_)
+ << "bgra8unorm should have been polyfilled to rgba8unorm";
+ return SpvImageFormatUnknown;
+ case builtin::TexelFormat::kR32Uint:
+ return SpvImageFormatR32ui;
+ case builtin::TexelFormat::kR32Sint:
+ return SpvImageFormatR32i;
+ case builtin::TexelFormat::kR32Float:
+ return SpvImageFormatR32f;
+ case builtin::TexelFormat::kRgba8Unorm:
+ return SpvImageFormatRgba8;
+ case builtin::TexelFormat::kRgba8Snorm:
+ return SpvImageFormatRgba8Snorm;
+ case builtin::TexelFormat::kRgba8Uint:
+ return SpvImageFormatRgba8ui;
+ case builtin::TexelFormat::kRgba8Sint:
+ return SpvImageFormatRgba8i;
+ case builtin::TexelFormat::kRg32Uint:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32ui;
+ case builtin::TexelFormat::kRg32Sint:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32i;
+ case builtin::TexelFormat::kRg32Float:
+ module_.PushCapability(SpvCapabilityStorageImageExtendedFormats);
+ return SpvImageFormatRg32f;
+ case builtin::TexelFormat::kRgba16Uint:
+ return SpvImageFormatRgba16ui;
+ case builtin::TexelFormat::kRgba16Sint:
+ return SpvImageFormatRgba16i;
+ case builtin::TexelFormat::kRgba16Float:
+ return SpvImageFormatRgba16f;
+ case builtin::TexelFormat::kRgba32Uint:
+ return SpvImageFormatRgba32ui;
+ case builtin::TexelFormat::kRgba32Sint:
+ return SpvImageFormatRgba32i;
+ case builtin::TexelFormat::kRgba32Float:
+ return SpvImageFormatRgba32f;
+ case builtin::TexelFormat::kUndefined:
+ return SpvImageFormatUnknown;
+ }
+ return SpvImageFormatUnknown;
+}
+
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index ae1bc95..0ca82ad 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -31,11 +31,13 @@
namespace tint::ir {
class Access;
class Binary;
+class Bitcast;
class Block;
class BlockParam;
class BuiltinCall;
class Construct;
class ControlInstruction;
+class Convert;
class ExitIf;
class ExitLoop;
class ExitSwitch;
@@ -47,13 +49,16 @@
class MultiInBlock;
class Store;
class Switch;
+class Swizzle;
class Terminator;
+class Unary;
class UserCall;
class Value;
class Var;
} // namespace tint::ir
namespace tint::type {
class Struct;
+class Texture;
class Type;
} // namespace tint::type
@@ -85,11 +90,6 @@
/// @returns the result ID of the constant
uint32_t Constant(ir::Constant* constant);
- /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
- /// @param type the type to get the ID for
- /// @returns the result ID of the OpConstantNull instruction
- uint32_t ConstantNull(const type::Type* type);
-
/// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
/// @param ty the type to get the ID for
/// @param addrspace the optional address space that this type is being used for
@@ -97,6 +97,34 @@
uint32_t Type(const type::Type* ty,
builtin::AddressSpace addrspace = builtin::AddressSpace::kUndefined);
+ private:
+ /// Convert a builtin to the corresponding SPIR-V enum value, taking into account the target
+ /// address space. Adds any capabilities needed for the builtin.
+ /// @param builtin the builtin to convert
+ /// @param addrspace the address space the builtin is being used in
+ /// @returns the enum value of the corresponding SPIR-V builtin
+ uint32_t Builtin(builtin::BuiltinValue builtin, builtin::AddressSpace addrspace);
+
+ /// Convert a texel format to the corresponding SPIR-V enum value, adding required capabilities.
+ /// @param format the format to convert
+ /// @returns the enum value of the corresponding SPIR-V texel format
+ uint32_t TexelFormat(const builtin::TexelFormat format);
+
+ /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
+ /// @param constant the constant to get the ID for
+ /// @returns the result ID of the constant
+ uint32_t Constant(const constant::Value* constant);
+
+ /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
+ /// @param type the type to get the ID for
+ /// @returns the result ID of the OpConstantNull instruction
+ uint32_t ConstantNull(const type::Type* type);
+
+ /// Get the ID of the label for `block`.
+ /// @param block the block to get the label ID for
+ /// @returns the ID of the block's label
+ uint32_t Label(ir::Block* block);
+
/// Get the result ID of the value `value`, emitting its instruction if necessary.
/// @param value the value to get the ID for
/// @returns the result ID of the value
@@ -107,10 +135,10 @@
/// @returns the result ID of the instruction
uint32_t Value(ir::Instruction* inst);
- /// Get the ID of the label for `block`.
- /// @param block the block to get the label ID for
- /// @returns the ID of the block's label
- uint32_t Label(ir::Block* block);
+ /// Get the result ID of the OpUndef instruction with type `ty`, emitting it if necessary.
+ /// @param ty the type of the undef value
+ /// @returns the result ID of the instruction
+ uint32_t Undef(const type::Type* ty);
/// Emit a struct type.
/// @param id the result ID to use
@@ -120,6 +148,11 @@
const type::Struct* str,
builtin::AddressSpace addrspace = builtin::AddressSpace::kUndefined);
+ /// Emit a texture type.
+ /// @param id the result ID to use
+ /// @param texture the texture type to emit
+ void EmitTextureType(uint32_t id, const type::Texture* texture);
+
/// Emit a function.
/// @param func the function to emit
void EmitFunction(ir::Function* func);
@@ -157,6 +190,10 @@
/// @param binary the binary instruction to emit
void EmitBinary(ir::Binary* binary);
+ /// Emit a bitcast instruction.
+ /// @param bitcast the bitcast instruction to emit
+ void EmitBitcast(ir::Bitcast* bitcast);
+
/// Emit a builtin function call instruction.
/// @param call the builtin call instruction to emit
void EmitBuiltinCall(ir::BuiltinCall* call);
@@ -165,6 +202,10 @@
/// @param construct the construct instruction to emit
void EmitConstruct(ir::Construct* construct);
+ /// Emit a convert instruction.
+ /// @param convert the convert instruction to emit
+ void EmitConvert(ir::Convert* convert);
+
/// Emit a load instruction.
/// @param load the load instruction to emit
void EmitLoad(ir::Load* load);
@@ -181,6 +222,14 @@
/// @param swtch the switch instruction to emit
void EmitSwitch(ir::Switch* swtch);
+ /// Emit a swizzle instruction.
+ /// @param swizzle the swizzle instruction to emit
+ void EmitSwizzle(ir::Swizzle* swizzle);
+
+ /// Emit a unary instruction.
+ /// @param unary the unary instruction to emit
+ void EmitUnary(ir::Unary* unary);
+
/// Emit a user call instruction.
/// @param call the user call instruction to emit
void EmitUserCall(ir::UserCall* call);
@@ -197,19 +246,6 @@
/// @param inst the flow control instruction
void EmitExitPhis(ir::ControlInstruction* inst);
- private:
- /// Convert a builtin to the corresponding SPIR-V enum value, taking into account the target
- /// address space. Adds any capabilities needed for the builtin.
- /// @param builtin the builtin to convert
- /// @param addrspace the address space the builtin is being used in
- /// @returns the enum value of the corresponding SPIR-V builtin
- uint32_t Builtin(builtin::BuiltinValue builtin, builtin::AddressSpace addrspace);
-
- /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
- /// @param constant the constant to get the ID for
- /// @returns the result ID of the constant
- uint32_t Constant(const constant::Value* constant);
-
ir::Module* ir_;
spirv::Module module_;
BinaryWriter writer_;
@@ -252,6 +288,9 @@
/// The map of types to the result IDs of their OpConstantNull instructions.
utils::Hashmap<const type::Type*, uint32_t, 4> constant_nulls_;
+ /// The map of types to the result IDs of their OpUndef instructions.
+ utils::Hashmap<const type::Type*, uint32_t, 4> undef_values_;
+
/// The map of non-constant values to their result IDs.
utils::Hashmap<ir::Value*, uint32_t, 8> values_;
@@ -267,6 +306,9 @@
/// The merge block for the current if statement
uint32_t if_merge_label_ = 0;
+ /// The header block for the current loop statement
+ uint32_t loop_header_label_ = 0;
+
/// The merge block for the current loop statement
uint32_t loop_merge_label_ = 0;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
index df494fb..86d09f7 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "gmock/gmock.h"
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
@@ -31,7 +30,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpCompositeExtract %int %arr 1");
}
@@ -44,7 +43,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %arr %uint_1");
}
@@ -59,7 +58,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %arr %idx");
}
@@ -75,7 +74,7 @@
mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_vector = OpCompositeExtract %v2float %mat 1");
EXPECT_INST("%result_scalar = OpCompositeExtract %float %mat 1 0");
}
@@ -91,7 +90,7 @@
mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %uint_1");
EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %uint_1 %uint_0");
}
@@ -109,7 +108,7 @@
mod.SetName(result_scalar, "result_scalar");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_vector = OpAccessChain %_ptr_Function_v2float %mat %idx");
EXPECT_INST("%result_scalar = OpAccessChain %_ptr_Function_float %mat %idx %idx");
}
@@ -124,7 +123,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpCompositeExtract %int %vec 1");
}
@@ -139,7 +138,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpVectorExtractDynamic %int %vec %idx");
}
@@ -152,7 +151,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %uint_1");
}
@@ -167,7 +166,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result = OpAccessChain %_ptr_Function_int %vec %idx");
}
@@ -182,7 +181,7 @@
mod.SetName(result, "result");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%14 = OpCompositeExtract %v4int %arr 1 2");
EXPECT_INST("%result = OpVectorExtractDynamic %int %14 %idx");
}
@@ -204,7 +203,7 @@
mod.SetName(result_b, "result_b");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_a = OpCompositeExtract %float %str 0");
EXPECT_INST("%result_b = OpCompositeExtract %int %str 1 2");
}
@@ -225,7 +224,7 @@
mod.SetName(result_b, "result_b");
});
- ASSERT_TRUE(Generate()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
EXPECT_INST("%result_a = OpAccessChain %_ptr_Function_float %str %uint_0");
EXPECT_INST("%result_b = OpAccessChain %_ptr_Function_int %str %uint_1 %uint_2");
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 6d4d890..bdd542c 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -14,7 +14,6 @@
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
-#include "gmock/gmock.h"
#include "src/tint/ir/binary.h"
using namespace tint::number_suffixes; // NOLINT
@@ -30,103 +29,190 @@
enum ir::Binary::Kind kind;
/// The expected SPIR-V instruction.
std::string spirv_inst;
+ /// The expected SPIR-V result type name.
+ std::string spirv_type_name;
};
-using Arithmetic = SpvGeneratorImplTestWithParam<BinaryTestCase>;
-TEST_P(Arithmetic, Scalar) {
+using Arithmetic_Bitwise = SpvGeneratorImplTestWithParam<BinaryTestCase>;
+TEST_P(Arithmetic_Bitwise, Scalar) {
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- b.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
- MakeScalarValue(params.type));
+ auto* lhs = MakeScalarValue(params.type);
+ auto* rhs = MakeScalarValue(params.type);
+ auto* result = b.Binary(params.kind, MakeScalarType(params.type), lhs, rhs);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %" + params.spirv_type_name);
}
-TEST_P(Arithmetic, Vector) {
+TEST_P(Arithmetic_Bitwise, Vector) {
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- b.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
+ auto* lhs = MakeVectorValue(params.type);
+ auto* rhs = MakeVectorValue(params.type);
+ auto* result = b.Binary(params.kind, MakeVectorType(params.type), lhs, rhs);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
-}
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_I32,
- Arithmetic,
- testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAdd, "OpIAdd"},
- BinaryTestCase{kI32, ir::Binary::Kind::kSubtract,
- "OpISub"}));
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_U32,
- Arithmetic,
- testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kAdd, "OpIAdd"},
- BinaryTestCase{kU32, ir::Binary::Kind::kSubtract,
- "OpISub"}));
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F32,
- Arithmetic,
- testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kAdd, "OpFAdd"},
- BinaryTestCase{kF32, ir::Binary::Kind::kSubtract,
- "OpFSub"}));
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16,
- Arithmetic,
- testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kAdd, "OpFAdd"},
- BinaryTestCase{kF16, ir::Binary::Kind::kSubtract,
- "OpFSub"}));
-
-using Bitwise = SpvGeneratorImplTestWithParam<BinaryTestCase>;
-TEST_P(Bitwise, Scalar) {
- auto params = GetParam();
-
- auto* func = b.Function("foo", ty.void_());
- b.With(func->Block(), [&] {
- b.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
- MakeScalarValue(params.type));
- b.Return(func);
- });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
-}
-TEST_P(Bitwise, Vector) {
- auto params = GetParam();
-
- auto* func = b.Function("foo", ty.void_());
- b.With(func->Block(), [&] {
- b.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
- b.Return(func);
- });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %v2" + params.spirv_type_name);
}
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_I32,
- Bitwise,
- testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAnd, "OpBitwiseAnd"},
- BinaryTestCase{kI32, ir::Binary::Kind::kOr, "OpBitwiseOr"},
- BinaryTestCase{kI32, ir::Binary::Kind::kXor, "OpBitwiseXor"}));
+ Arithmetic_Bitwise,
+ testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAdd, "OpIAdd", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kSubtract, "OpISub", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kMultiply, "OpIMul", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kDivide, "OpSDiv", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kModulo, "OpSRem", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kAnd, "OpBitwiseAnd", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kOr, "OpBitwiseOr", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kXor, "OpBitwiseXor", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kShiftLeft, "OpShiftLeftLogical", "int"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kShiftRight, "OpShiftRightArithmetic",
+ "int"}));
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_U32,
- Bitwise,
- testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kAnd, "OpBitwiseAnd"},
- BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr"},
- BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor"}));
+ Arithmetic_Bitwise,
+ testing::Values(
+ BinaryTestCase{kU32, ir::Binary::Kind::kAdd, "OpIAdd", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kSubtract, "OpISub", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kMultiply, "OpIMul", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kDivide, "OpUDiv", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kModulo, "OpUMod", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kAnd, "OpBitwiseAnd", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kShiftLeft, "OpShiftLeftLogical", "uint"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kShiftRight, "OpShiftRightLogical", "uint"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F32,
+ Arithmetic_Bitwise,
+ testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kAdd, "OpFAdd", "float"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kSubtract, "OpFSub", "float"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kMultiply, "OpFMul", "float"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kDivide, "OpFDiv", "float"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kModulo, "OpFRem", "float"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Binary_F16,
+ Arithmetic_Bitwise,
+ testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kAdd, "OpFAdd", "half"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kSubtract, "OpFSub", "half"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kMultiply, "OpFMul", "half"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kDivide, "OpFDiv", "half"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kModulo, "OpFRem", "half"}));
+
+TEST_F(SpvGeneratorImplTest, Binary_ScalarTimesVector_F32) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({scalar, vector});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), scalar, vector);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorTimesScalar %v4float %vector %scalar");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_VectorTimesScalar_F32) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({scalar, vector});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), vector, scalar);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorTimesScalar %v4float %vector %scalar");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_ScalarTimesMatrix_F32) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* matrix = b.FunctionParam("matrix", ty.mat3x4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({scalar, matrix});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat3x4<f32>(), scalar, matrix);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpMatrixTimesScalar %mat3v4float %matrix %scalar");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_MatrixTimesScalar_F32) {
+ auto* scalar = b.FunctionParam("scalar", ty.f32());
+ auto* matrix = b.FunctionParam("matrix", ty.mat3x4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({scalar, matrix});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat3x4<f32>(), matrix, scalar);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpMatrixTimesScalar %mat3v4float %matrix %scalar");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_VectorTimesMatrix_F32) {
+ auto* vector = b.FunctionParam("vector", ty.vec4<f32>());
+ auto* matrix = b.FunctionParam("matrix", ty.mat3x4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vector, matrix});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec3<f32>(), vector, matrix);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorTimesMatrix %v3float %vector %matrix");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_MatrixTimesVector_F32) {
+ auto* vector = b.FunctionParam("vector", ty.vec3<f32>());
+ auto* matrix = b.FunctionParam("matrix", ty.mat3x4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vector, matrix});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.vec4<f32>(), matrix, vector);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpMatrixTimesVector %v4float %matrix %vector");
+}
+
+TEST_F(SpvGeneratorImplTest, Binary_MatrixTimesMatrix_F32) {
+ auto* mat1 = b.FunctionParam("mat1", ty.mat4x3<f32>());
+ auto* mat2 = b.FunctionParam("mat2", ty.mat3x4<f32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({mat1, mat2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Multiply(ty.mat3x3<f32>(), mat1, mat2);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpMatrixTimesMatrix %mat3v3float %mat1 %mat2");
+}
using Comparison = SpvGeneratorImplTestWithParam<BinaryTestCase>;
TEST_P(Comparison, Scalar) {
@@ -134,15 +220,15 @@
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- b.Binary(params.kind, ty.bool_(), MakeScalarValue(params.type),
- MakeScalarValue(params.type));
+ auto* lhs = MakeScalarValue(params.type);
+ auto* rhs = MakeScalarValue(params.type);
+ auto* result = b.Binary(params.kind, ty.bool_(), lhs, rhs);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %bool");
}
TEST_P(Comparison, Vector) {
@@ -150,87 +236,77 @@
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- b.Binary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
+ auto* lhs = MakeVectorValue(params.type);
+ auto* rhs = MakeVectorValue(params.type);
+ auto* result = b.Binary(params.kind, ty.vec2<bool>(), lhs, rhs);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %v2bool");
}
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_I32,
Comparison,
- testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual"},
- BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
- BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan"},
- BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual,
- "OpSGreaterThanEqual"},
- BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan"},
- BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual"}));
+ testing::Values(
+ BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual", "bool"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual", "bool"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan", "bool"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual, "OpSGreaterThanEqual", "bool"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan", "bool"},
+ BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_U32,
Comparison,
- testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual"},
- BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual"},
- BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan"},
- BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual,
- "OpUGreaterThanEqual"},
- BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan"},
- BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual"}));
+ testing::Values(
+ BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual", "bool"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual", "bool"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan", "bool"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual, "OpUGreaterThanEqual", "bool"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan", "bool"},
+ BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_F32,
Comparison,
- testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
- BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
- BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
- BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual,
- "OpFOrdGreaterThanEqual"},
- BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
- BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
+ testing::Values(
+ BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual", "bool"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual", "bool"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan", "bool"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual, "OpFOrdGreaterThanEqual", "bool"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan", "bool"},
+ BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
INSTANTIATE_TEST_SUITE_P(
SpvGeneratorImplTest_Binary_F16,
Comparison,
- testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual"},
- BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"},
- BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"},
- BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual,
- "OpFOrdGreaterThanEqual"},
- BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"},
- BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"}));
-INSTANTIATE_TEST_SUITE_P(
- SpvGeneratorImplTest_Binary_Bool,
- Comparison,
- testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual, "OpLogicalEqual"},
- BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
+ testing::Values(
+ BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual", "bool"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual", "bool"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan", "bool"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual, "OpFOrdGreaterThanEqual", "bool"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan", "bool"},
+ BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual", "bool"}));
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_Bool,
+ Comparison,
+ testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual,
+ "OpLogicalEqual", "bool"},
+ BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual,
+ "OpLogicalNotEqual", "bool"}));
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
auto* func = b.Function("foo", ty.void_());
b.With(func->Block(), [&] {
- auto* a = b.Subtract(ty.i32(), 1_i, 2_i);
- b.Add(ty.i32(), a, a);
+ auto* sub = b.Subtract(ty.i32(), 1_i, 2_i);
+ auto* add = b.Add(ty.i32(), sub, sub);
b.Return(func);
+ mod.SetName(sub, "sub");
+ mod.SetName(add, "add");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%6 = OpTypeInt 32 1
-%7 = OpConstant %6 1
-%8 = OpConstant %6 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpISub %6 %7 %8
-%9 = OpIAdd %6 %5 %5
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%sub = OpISub %int %int_1 %int_2");
+ EXPECT_INST("%add = OpIAdd %int %sub %sub");
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_bitcast_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_bitcast_test.cc
new file mode 100644
index 0000000..f8d37db
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_bitcast_test.cc
@@ -0,0 +1,148 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+namespace tint::writer::spirv {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+/// A parameterized test case.
+struct BitcastCase {
+ /// The input type.
+ TestElementType in;
+ /// The output type.
+ TestElementType out;
+ /// The expected SPIR-V result type name.
+ std::string spirv_type_name;
+};
+std::string PrintCase(testing::TestParamInfo<BitcastCase> cc) {
+ utils::StringStream ss;
+ ss << cc.param.in << "_to_" << cc.param.out;
+ return ss.str();
+}
+
+using Bitcast = SpvGeneratorImplTestWithParam<BitcastCase>;
+TEST_P(Bitcast, Scalar) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeScalarType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeScalarType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(MakeScalarType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ if (params.in == params.out) {
+ EXPECT_INST("OpReturnValue %arg");
+ } else {
+ EXPECT_INST("%result = OpBitcast %" + params.spirv_type_name + " %arg");
+ }
+}
+TEST_P(Bitcast, Vector) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeVectorType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeVectorType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(MakeVectorType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ if (params.in == params.out) {
+ EXPECT_INST("OpReturnValue %arg");
+ } else {
+ EXPECT_INST("%result = OpBitcast %v2" + params.spirv_type_name + " %arg");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Bitcast,
+ testing::Values(
+ // To f32.
+ BitcastCase{kF32, kF32, "float"},
+ BitcastCase{kI32, kF32, "float"},
+ BitcastCase{kU32, kF32, "float"},
+
+ // To f16.
+ BitcastCase{kF16, kF16, "half"},
+
+ // To i32.
+ BitcastCase{kF32, kI32, "int"},
+ BitcastCase{kI32, kI32, "int"},
+ BitcastCase{kU32, kI32, "int"},
+
+ // To u32.
+ BitcastCase{kF32, kU32, "uint"},
+ BitcastCase{kI32, kU32, "uint"},
+ BitcastCase{kU32, kU32, "uint"}),
+ PrintCase);
+
+TEST_F(SpvGeneratorImplTest, Bitcast_u32_to_vec2h) {
+ auto* func = b.Function("foo", ty.vec2<f16>());
+ func->SetParams({b.FunctionParam("arg", ty.u32())});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(ty.vec2<f16>(), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpBitcast %v2half %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Bitcast_vec2i_to_vec4h) {
+ auto* func = b.Function("foo", ty.vec4<f16>());
+ func->SetParams({b.FunctionParam("arg", ty.vec2<i32>())});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(ty.vec4<f16>(), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpBitcast %v4half %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Bitcast_vec2h_to_u32) {
+ auto* func = b.Function("foo", ty.u32());
+ func->SetParams({b.FunctionParam("arg", ty.vec2<f16>())});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(ty.u32(), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpBitcast %uint %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Bitcast_vec4h_to_vec2i) {
+ auto* func = b.Function("foo", ty.vec2<i32>());
+ func->SetParams({b.FunctionParam("arg", ty.vec4<f16>())});
+ b.With(func->Block(), [&] {
+ auto* result = b.Bitcast(ty.vec2<i32>(), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpBitcast %v2int %arg");
+}
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index f9bbcd6..0ab0a39 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -14,7 +14,6 @@
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
-#include "gmock/gmock.h"
#include "src/tint/builtin/function.h"
using namespace tint::number_suffixes; // NOLINT
@@ -43,10 +42,8 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
}
TEST_P(Builtin_1arg, Vector) {
auto params = GetParam();
@@ -57,60 +54,137 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
}
-INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
- Builtin_1arg,
- testing::Values(BuiltinTestCase{kI32, builtin::Function::kAbs, "SAbs"},
- BuiltinTestCase{kF32, builtin::Function::kAbs, "FAbs"}));
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Builtin_1arg,
+ testing::Values(BuiltinTestCase{kI32, builtin::Function::kAbs, "SAbs"},
+ BuiltinTestCase{kF32, builtin::Function::kAbs, "FAbs"},
+ BuiltinTestCase{kF16, builtin::Function::kAbs, "FAbs"},
+ BuiltinTestCase{kF32, builtin::Function::kAcos, "Acos"},
+ BuiltinTestCase{kF16, builtin::Function::kAcos, "Acos"},
+ BuiltinTestCase{kF32, builtin::Function::kAsinh, "Asinh"},
+ BuiltinTestCase{kF16, builtin::Function::kAsinh, "Asinh"},
+ BuiltinTestCase{kF32, builtin::Function::kAcos, "Acos"},
+ BuiltinTestCase{kF16, builtin::Function::kAcos, "Acos"},
+ BuiltinTestCase{kF32, builtin::Function::kAsinh, "Asinh"},
+ BuiltinTestCase{kF16, builtin::Function::kAsinh, "Asinh"},
+ BuiltinTestCase{kF32, builtin::Function::kAtan, "Atan"},
+ BuiltinTestCase{kF16, builtin::Function::kAtan, "Atan"},
+ BuiltinTestCase{kF32, builtin::Function::kAtanh, "Atanh"},
+ BuiltinTestCase{kF16, builtin::Function::kAtanh, "Atanh"},
+ BuiltinTestCase{kF32, builtin::Function::kCos, "Cos"},
+ BuiltinTestCase{kF16, builtin::Function::kCos, "Cos"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdx, "OpDPdx"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdxCoarse, "OpDPdxCoarse"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdxFine, "OpDPdxFine"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdy, "OpDPdy"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdyCoarse, "OpDPdyCoarse"},
+ BuiltinTestCase{kF32, builtin::Function::kDpdyFine, "OpDPdyFine"},
+ BuiltinTestCase{kF32, builtin::Function::kSin, "Sin"},
+ BuiltinTestCase{kF16, builtin::Function::kSin, "Sin"},
+ BuiltinTestCase{kF32, builtin::Function::kTan, "Tan"},
+ BuiltinTestCase{kF16, builtin::Function::kTan, "Tan"},
+ BuiltinTestCase{kF32, builtin::Function::kCosh, "Cosh"},
+ BuiltinTestCase{kF16, builtin::Function::kCosh, "Cosh"},
+ BuiltinTestCase{kF32, builtin::Function::kSinh, "Sinh"},
+ BuiltinTestCase{kF16, builtin::Function::kSinh, "Sinh"},
+ BuiltinTestCase{kF32, builtin::Function::kTanh, "Tanh"},
+ BuiltinTestCase{kF16, builtin::Function::kTanh, "Tanh"}));
// Test that abs of an unsigned value just folds away.
TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) {
auto* func = b.Function("foo", MakeScalarType(kU32));
b.With(func->Block(), [&] {
- auto* result = b.Call(MakeScalarType(kU32), builtin::Function::kAbs, MakeScalarValue(kU32));
+ auto* arg = MakeScalarValue(kU32);
+ auto* result = b.Call(MakeScalarType(kU32), builtin::Function::kAbs, arg);
b.Return(func, result);
+ mod.SetName(arg, "arg");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 0
-%3 = OpTypeFunction %2
-%5 = OpConstant %2 1
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturnValue %5
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %foo = OpFunction %uint None %3
+ %4 = OpLabel
+ OpReturnValue %arg
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) {
auto* func = b.Function("foo", MakeVectorType(kU32));
b.With(func->Block(), [&] {
- auto* result = b.Call(MakeVectorType(kU32), builtin::Function::kAbs, MakeVectorValue(kU32));
+ auto* arg = MakeVectorValue(kU32);
+ auto* result = b.Call(MakeVectorType(kU32), builtin::Function::kAbs, arg);
+ b.Return(func, result);
+ mod.SetName(arg, "arg");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %foo = OpFunction %v2uint None %4
+ %5 = OpLabel
+ OpReturnValue %arg
+ OpFunctionEnd
+)");
+}
+
+// Test that any of an scalar just folds away.
+TEST_F(SpvGeneratorImplTest, Builtin_Any_Scalar) {
+ auto* arg = b.FunctionParam("arg", ty.bool_());
+ auto* func = b.Function("foo", ty.bool_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.bool_(), builtin::Function::kAny, arg);
b.Return(func, result);
});
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpReturnValue %arg");
+}
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%3 = OpTypeInt 32 0
-%2 = OpTypeVector %3 2
-%4 = OpTypeFunction %2
-%7 = OpConstant %3 42
-%8 = OpConstant %3 10
-%6 = OpConstantComposite %2 %7 %8
-%1 = OpFunction %2 None %4
-%5 = OpLabel
-OpReturnValue %6
-OpFunctionEnd
-)");
+TEST_F(SpvGeneratorImplTest, Builtin_Any_Vector) {
+ auto* arg = b.FunctionParam("arg", ty.vec4<bool>());
+ auto* func = b.Function("foo", ty.bool_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.bool_(), builtin::Function::kAny, arg);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpAny %bool %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Length_vec4f) {
+ auto* arg = b.FunctionParam("arg", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.f32(), builtin::Function::kLength, arg);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %float %8 Length %arg");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Normalize_vec4f) {
+ auto* arg = b.FunctionParam("arg", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), builtin::Function::kNormalize, arg);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %v4float %8 Normalize %arg");
}
// Tests for builtins with the signature: T = func(T, T)
@@ -125,10 +199,8 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
}
TEST_P(Builtin_2arg, Vector) {
auto params = GetParam();
@@ -140,19 +212,98 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
}
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
Builtin_2arg,
- testing::Values(BuiltinTestCase{kF32, builtin::Function::kMax, "FMax"},
+ testing::Values(BuiltinTestCase{kF32, builtin::Function::kAtan2, "Atan2"},
+ BuiltinTestCase{kF32, builtin::Function::kMax, "FMax"},
BuiltinTestCase{kI32, builtin::Function::kMax, "SMax"},
BuiltinTestCase{kU32, builtin::Function::kMax, "UMax"},
BuiltinTestCase{kF32, builtin::Function::kMin, "FMin"},
BuiltinTestCase{kI32, builtin::Function::kMin, "SMin"},
BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"}));
+TEST_F(SpvGeneratorImplTest, Builtin_Cross_vec3f) {
+ auto* arg1 = b.FunctionParam("arg1", ty.vec3<f32>());
+ auto* arg2 = b.FunctionParam("arg2", ty.vec3<f32>());
+ auto* func = b.Function("foo", ty.vec3<f32>());
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(ty.vec3<f32>(), builtin::Function::kCross, arg1, arg2);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %v3float %9 Cross %arg1 %arg2");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Distance_vec2f) {
+ auto* arg1 = b.FunctionParam("arg1", MakeVectorType(kF32));
+ auto* arg2 = b.FunctionParam("arg2", MakeVectorType(kF32));
+ auto* func = b.Function("foo", MakeScalarType(kF32));
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(MakeScalarType(kF32), builtin::Function::kDistance, arg1, arg2);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %float %9 Distance %arg1 %arg2");
+}
+
+TEST_F(SpvGeneratorImplTest, Builtin_Distance_vec3h) {
+ auto* arg1 = b.FunctionParam("arg1", MakeVectorType(kF16));
+ auto* arg2 = b.FunctionParam("arg2", MakeVectorType(kF16));
+ auto* func = b.Function("foo", MakeScalarType(kF16));
+ func->SetParams({arg1, arg2});
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(MakeScalarType(kF16), builtin::Function::kDistance, arg1, arg2);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpExtInst %half %9 Distance %arg1 %arg2");
+}
+
+// Tests for builtins with the signature: T = func(T, T, T)
+using Builtin_3arg = SpvGeneratorImplTestWithParam<BuiltinTestCase>;
+TEST_P(Builtin_3arg, Scalar) {
+ auto params = GetParam();
+
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ b.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type),
+ MakeScalarValue(params.type), MakeScalarValue(params.type));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
+}
+TEST_P(Builtin_3arg, Vector) {
+ auto params = GetParam();
+
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ b.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type),
+ MakeVectorValue(params.type), MakeVectorValue(params.type));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.spirv_inst);
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Builtin_3arg,
+ testing::Values(BuiltinTestCase{kF32, builtin::Function::kClamp, "NClamp"},
+ BuiltinTestCase{kI32, builtin::Function::kClamp, "SClamp"},
+ BuiltinTestCase{kU32, builtin::Function::kClamp,
+ "UClamp"}));
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
index 40fae96..e526f0b 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -22,46 +22,41 @@
TEST_F(SpvGeneratorImplTest, Constant_Bool) {
generator_.Constant(b.Constant(true));
generator_.Constant(b.Constant(false));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeBool
-%1 = OpConstantTrue %2
-%3 = OpConstantFalse %2
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%true = OpConstantTrue %bool");
+ EXPECT_INST("%false = OpConstantFalse %bool");
}
TEST_F(SpvGeneratorImplTest, Constant_I32) {
generator_.Constant(b.Constant(i32(42)));
generator_.Constant(b.Constant(i32(-1)));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 1
-%1 = OpConstant %2 42
-%3 = OpConstant %2 -1
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%int_42 = OpConstant %int 42");
+ EXPECT_INST("%int_n1 = OpConstant %int -1");
}
TEST_F(SpvGeneratorImplTest, Constant_U32) {
generator_.Constant(b.Constant(u32(42)));
generator_.Constant(b.Constant(u32(4000000000)));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 0
-%1 = OpConstant %2 42
-%3 = OpConstant %2 4000000000
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%uint_42 = OpConstant %uint 42");
+ EXPECT_INST("%uint_4000000000 = OpConstant %uint 4000000000");
}
TEST_F(SpvGeneratorImplTest, Constant_F32) {
generator_.Constant(b.Constant(f32(42)));
generator_.Constant(b.Constant(f32(-1)));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
-%1 = OpConstant %2 42
-%3 = OpConstant %2 -1
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%float_42 = OpConstant %float 42");
+ EXPECT_INST("%float_n1 = OpConstant %float -1");
}
TEST_F(SpvGeneratorImplTest, Constant_F16) {
generator_.Constant(b.Constant(f16(42)));
generator_.Constant(b.Constant(f16(-1)));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 16
-%1 = OpConstant %2 0x1.5p+5
-%3 = OpConstant %2 -0x1p+0
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%half_0x1_5p_5 = OpConstant %half 0x1.5p+5");
+ EXPECT_INST("%half_n0x1p_0 = OpConstant %half -0x1p+0");
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
@@ -71,12 +66,8 @@
utils::Vector{const_bool(true), const_bool(false), const_bool(false), const_bool(true)});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool
-%2 = OpTypeVector %3 4
-%4 = OpConstantTrue %3
-%5 = OpConstantFalse %3
-%1 = OpConstantComposite %2 %4 %5 %5 %4
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %v4bool %true %false %false %true");
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
@@ -84,12 +75,8 @@
auto* v = mod.constant_values.Composite(ty.vec2(ty.i32()),
utils::Vector{const_i32(42), const_i32(-1)});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
-%2 = OpTypeVector %3 2
-%4 = OpConstant %3 42
-%5 = OpConstant %3 -1
-%1 = OpConstantComposite %2 %4 %5
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %v2int %int_42 %int_n1");
}
TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
@@ -97,13 +84,8 @@
auto* v = mod.constant_values.Composite(
ty.vec3(ty.u32()), utils::Vector{const_u32(42), const_u32(0), const_u32(4000000000)});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0
-%2 = OpTypeVector %3 3
-%4 = OpConstant %3 42
-%5 = OpConstant %3 0
-%6 = OpConstant %3 4000000000
-%1 = OpConstantComposite %2 %4 %5 %6
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %v3uint %uint_42 %uint_0 %uint_4000000000");
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
@@ -112,14 +94,8 @@
ty.vec4(ty.f32()),
utils::Vector{const_f32(42), const_f32(0), const_f32(0.25), const_f32(-1)});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32
-%2 = OpTypeVector %3 4
-%4 = OpConstant %3 42
-%5 = OpConstant %3 0
-%6 = OpConstant %3 0.25
-%7 = OpConstant %3 -1
-%1 = OpConstantComposite %2 %4 %5 %6 %7
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %v4float %float_42 %float_0 %float_0_25 %float_n1");
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
@@ -127,12 +103,8 @@
auto* v = mod.constant_values.Composite(ty.vec2(ty.f16()),
utils::Vector{const_f16(42), const_f16(0.25)});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16
-%2 = OpTypeVector %3 2
-%4 = OpConstant %3 0x1.5p+5
-%5 = OpConstant %3 0x1p-2
-%1 = OpConstantComposite %2 %4 %5
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %v2half %half_0x1_5p_5 %half_0x1pn2");
}
TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) {
@@ -147,18 +119,17 @@
ty.vec3(f32), utils::Vector{const_f32(-42), const_f32(0), const_f32(-0.25)}),
});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
-%3 = OpTypeVector %4 3
-%2 = OpTypeMatrix %3 2
-%6 = OpConstant %4 42
-%7 = OpConstant %4 -1
-%8 = OpConstant %4 0.25
-%5 = OpConstantComposite %3 %6 %7 %8
-%10 = OpConstant %4 -42
-%11 = OpConstant %4 0
-%12 = OpConstant %4 -0.25
-%9 = OpConstantComposite %3 %10 %11 %12
-%1 = OpConstantComposite %2 %5 %9
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %float_42 = OpConstant %float 42
+ %float_n1 = OpConstant %float -1
+ %float_0_25 = OpConstant %float 0.25
+ %5 = OpConstantComposite %v3float %float_42 %float_n1 %float_0_25
+ %float_n42 = OpConstant %float -42
+ %float_0 = OpConstant %float 0
+%float_n0_25 = OpConstant %float -0.25
+ %9 = OpConstantComposite %v3float %float_n42 %float_0 %float_n0_25
+ %1 = OpConstantComposite %mat2v3float %5 %9
)");
}
@@ -177,21 +148,20 @@
ty.vec2(f16), utils::Vector{const_f16(0.5), const_f16(-0)}),
});
generator_.Constant(b.Constant(v));
- EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 16
-%3 = OpTypeVector %4 2
-%2 = OpTypeMatrix %3 4
-%6 = OpConstant %4 0x1.5p+5
-%7 = OpConstant %4 -0x1p+0
-%5 = OpConstantComposite %3 %6 %7
-%9 = OpConstant %4 0x0p+0
-%10 = OpConstant %4 0x1p-2
-%8 = OpConstantComposite %3 %9 %10
-%12 = OpConstant %4 -0x1.5p+5
-%13 = OpConstant %4 0x1p+0
-%11 = OpConstantComposite %3 %12 %13
-%15 = OpConstant %4 0x1p-1
-%14 = OpConstantComposite %3 %15 %9
-%1 = OpConstantComposite %2 %5 %8 %11 %14
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+%half_0x1_5p_5 = OpConstant %half 0x1.5p+5
+%half_n0x1p_0 = OpConstant %half -0x1p+0
+ %5 = OpConstantComposite %v2half %half_0x1_5p_5 %half_n0x1p_0
+%half_0x0p_0 = OpConstant %half 0x0p+0
+%half_0x1pn2 = OpConstant %half 0x1p-2
+ %8 = OpConstantComposite %v2half %half_0x0p_0 %half_0x1pn2
+%half_n0x1_5p_5 = OpConstant %half -0x1.5p+5
+%half_0x1p_0 = OpConstant %half 0x1p+0
+ %11 = OpConstantComposite %v2half %half_n0x1_5p_5 %half_0x1p_0
+%half_0x1pn1 = OpConstant %half 0x1p-1
+ %14 = OpConstantComposite %v2half %half_0x1pn1 %half_0x0p_0
+ %1 = OpConstantComposite %mat4v2half %5 %8 %11 %14
)");
}
@@ -204,16 +174,8 @@
mod.constant_values.Get(4_i),
});
generator_.Constant(b.Constant(arr));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
-%5 = OpTypeInt 32 0
-%4 = OpConstant %5 4
-%2 = OpTypeArray %3 %4
-%6 = OpConstant %3 1
-%7 = OpConstant %3 2
-%8 = OpConstant %3 3
-%9 = OpConstant %3 4
-%1 = OpConstantComposite %2 %6 %7 %8 %9
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4");
}
TEST_F(SpvGeneratorImplTest, Constant_Array_Array_I32) {
@@ -231,21 +193,14 @@
inner,
});
generator_.Constant(b.Constant(arr));
- EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeInt 32 1
-%6 = OpTypeInt 32 0
-%5 = OpConstant %6 4
-%3 = OpTypeArray %4 %5
-%2 = OpTypeArray %3 %5
-%8 = OpConstant %4 1
-%9 = OpConstant %4 2
-%10 = OpConstant %4 3
-%11 = OpConstant %4 4
-%7 = OpConstantComposite %3 %8 %9 %10 %11
-%1 = OpConstantComposite %2 %7 %7 %7 %7
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %7 = OpConstantComposite %_arr_int_uint_4 %int_1 %int_2 %int_3 %int_4
+ %1 = OpConstantComposite %_arr__arr_int_uint_4_uint_4 %7 %7 %7 %7
)");
}
-TEST_F(SpvGeneratorImplTest, Struct) {
+TEST_F(SpvGeneratorImplTest, Constant_Struct) {
auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
{mod.symbols.New("a"), ty.i32()},
{mod.symbols.New("b"), ty.u32()},
@@ -257,15 +212,8 @@
mod.constant_values.Get(3_f),
});
generator_.Constant(b.Constant(str));
- EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
-%4 = OpTypeInt 32 0
-%5 = OpTypeFloat 32
-%2 = OpTypeStruct %3 %4 %5
-%6 = OpConstant %3 1
-%7 = OpConstant %4 2
-%8 = OpConstant %5 3
-%1 = OpConstantComposite %2 %6 %7 %8
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpConstantComposite %MyStruct %int_1 %uint_2 %float_3");
}
// Test that we do not emit the same constant more than once.
@@ -273,9 +221,8 @@
generator_.Constant(b.Constant(i32(42)));
generator_.Constant(b.Constant(i32(42)));
generator_.Constant(b.Constant(i32(42)));
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 1
-%1 = OpConstant %2 42
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%int_42 = OpConstant %int 42");
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
index 7ba7cff..0f33beb 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
@@ -23,93 +23,54 @@
TEST_F(SpvGeneratorImplTest, Construct_Vector) {
auto* func = b.Function("foo", ty.vec4<i32>());
func->SetParams({
- b.FunctionParam(ty.i32()),
- b.FunctionParam(ty.i32()),
- b.FunctionParam(ty.i32()),
- b.FunctionParam(ty.i32()),
+ b.FunctionParam("a", ty.i32()),
+ b.FunctionParam("b", ty.i32()),
+ b.FunctionParam("c", ty.i32()),
+ b.FunctionParam("d", ty.i32()),
+ });
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.vec4<i32>(), func->Params());
+ b.Return(func, result);
+ mod.SetName(result, "result");
});
- b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.vec4<i32>(), func->Params())); });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%3 = OpTypeInt 32 1
-%2 = OpTypeVector %3 4
-%8 = OpTypeFunction %2 %3 %3 %3 %3
-%1 = OpFunction %2 None %8
-%4 = OpFunctionParameter %3
-%5 = OpFunctionParameter %3
-%6 = OpFunctionParameter %3
-%7 = OpFunctionParameter %3
-%9 = OpLabel
-%10 = OpCompositeConstruct %2 %4 %5 %6 %7
-OpReturnValue %10
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpCompositeConstruct %v4int %a %b %c %d");
}
TEST_F(SpvGeneratorImplTest, Construct_Matrix) {
auto* func = b.Function("foo", ty.mat3x4<f32>());
func->SetParams({
- b.FunctionParam(ty.vec4<f32>()),
- b.FunctionParam(ty.vec4<f32>()),
- b.FunctionParam(ty.vec4<f32>()),
+ b.FunctionParam("a", ty.vec4<f32>()),
+ b.FunctionParam("b", ty.vec4<f32>()),
+ b.FunctionParam("c", ty.vec4<f32>()),
+ });
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.mat3x4<f32>(), func->Params());
+ b.Return(func, result);
+ mod.SetName(result, "result");
});
- b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.mat3x4<f32>(), func->Params())); });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%4 = OpTypeFloat 32
-%3 = OpTypeVector %4 4
-%2 = OpTypeMatrix %3 3
-%8 = OpTypeFunction %2 %3 %3 %3
-%1 = OpFunction %2 None %8
-%5 = OpFunctionParameter %3
-%6 = OpFunctionParameter %3
-%7 = OpFunctionParameter %3
-%9 = OpLabel
-%10 = OpCompositeConstruct %2 %5 %6 %7
-OpReturnValue %10
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpCompositeConstruct %mat3v4float %a %b %c");
}
TEST_F(SpvGeneratorImplTest, Construct_Array) {
auto* func = b.Function("foo", ty.array<f32, 4>());
func->SetParams({
- b.FunctionParam(ty.f32()),
- b.FunctionParam(ty.f32()),
- b.FunctionParam(ty.f32()),
- b.FunctionParam(ty.f32()),
+ b.FunctionParam("a", ty.f32()),
+ b.FunctionParam("b", ty.f32()),
+ b.FunctionParam("c", ty.f32()),
+ b.FunctionParam("d", ty.f32()),
+ });
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(ty.array<f32, 4>(), func->Params());
+ b.Return(func, result);
+ mod.SetName(result, "result");
});
- b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.array<f32, 4>(), func->Params())); });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpDecorate %2 ArrayStride 4
-%3 = OpTypeFloat 32
-%5 = OpTypeInt 32 0
-%4 = OpConstant %5 4
-%2 = OpTypeArray %3 %4
-%10 = OpTypeFunction %2 %3 %3 %3 %3
-%1 = OpFunction %2 None %10
-%6 = OpFunctionParameter %3
-%7 = OpFunctionParameter %3
-%8 = OpFunctionParameter %3
-%9 = OpFunctionParameter %3
-%11 = OpLabel
-%12 = OpCompositeConstruct %2 %6 %7 %8 %9
-OpReturnValue %12
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpCompositeConstruct %_arr_float_uint_4 %a %b %c %d");
}
TEST_F(SpvGeneratorImplTest, Construct_Struct) {
@@ -119,42 +80,20 @@
{mod.symbols.Register("b"), ty.u32()},
{mod.symbols.Register("c"), ty.vec4<f32>()},
});
-
auto* func = b.Function("foo", str);
func->SetParams({
- b.FunctionParam(ty.i32()),
- b.FunctionParam(ty.u32()),
- b.FunctionParam(ty.vec4<f32>()),
+ b.FunctionParam("a", ty.i32()),
+ b.FunctionParam("b", ty.u32()),
+ b.FunctionParam("c", ty.vec4<f32>()),
+ });
+ b.With(func->Block(), [&] {
+ auto* result = b.Construct(str, func->Params());
+ b.Return(func, result);
+ mod.SetName(result, "result");
});
- b.With(func->Block(), [&] { b.Return(func, b.Construct(str, func->Params())); });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpMemberName %2 0 "a"
-OpMemberName %2 1 "b"
-OpMemberName %2 2 "c"
-OpName %2 "MyStruct"
-OpMemberDecorate %2 0 Offset 0
-OpMemberDecorate %2 1 Offset 4
-OpMemberDecorate %2 2 Offset 16
-%3 = OpTypeInt 32 1
-%4 = OpTypeInt 32 0
-%6 = OpTypeFloat 32
-%5 = OpTypeVector %6 4
-%2 = OpTypeStruct %3 %4 %5
-%10 = OpTypeFunction %2 %3 %4 %5
-%1 = OpFunction %2 None %10
-%7 = OpFunctionParameter %3
-%8 = OpFunctionParameter %4
-%9 = OpFunctionParameter %5
-%11 = OpLabel
-%12 = OpCompositeConstruct %2 %7 %8 %9
-OpReturnValue %12
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpCompositeConstruct %MyStruct %a %b %c");
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc
new file mode 100644
index 0000000..22481bb
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_convert_test.cc
@@ -0,0 +1,102 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+namespace tint::writer::spirv {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+/// A parameterized test case.
+struct ConvertCase {
+ /// The input type.
+ TestElementType in;
+ /// The output type.
+ TestElementType out;
+ /// The expected SPIR-V instruction.
+ std::string spirv_inst;
+ /// The expected SPIR-V result type name.
+ std::string spirv_type_name;
+};
+std::string PrintCase(testing::TestParamInfo<ConvertCase> cc) {
+ utils::StringStream ss;
+ ss << cc.param.in << "_to_" << cc.param.out;
+ return ss.str();
+}
+
+using Convert = SpvGeneratorImplTestWithParam<ConvertCase>;
+TEST_P(Convert, Scalar) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeScalarType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeScalarType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Convert(MakeScalarType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %" + params.spirv_type_name + " %arg");
+}
+TEST_P(Convert, Vector) {
+ auto& params = GetParam();
+ auto* func = b.Function("foo", MakeVectorType(params.out));
+ func->SetParams({b.FunctionParam("arg", MakeVectorType(params.in))});
+ b.With(func->Block(), [&] {
+ auto* result = b.Convert(MakeVectorType(params.out), func->Params()[0]);
+ b.Return(func, result);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %v2" + params.spirv_type_name + " %arg");
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Convert,
+ testing::Values(
+ // To f32.
+ ConvertCase{kF16, kF32, "OpFConvert", "float"},
+ ConvertCase{kI32, kF32, "OpConvertSToF", "float"},
+ ConvertCase{kU32, kF32, "OpConvertUToF", "float"},
+ ConvertCase{kBool, kF32, "OpSelect", "float"},
+
+ // To f16.
+ ConvertCase{kF32, kF16, "OpFConvert", "half"},
+ ConvertCase{kI32, kF16, "OpConvertSToF", "half"},
+ ConvertCase{kU32, kF16, "OpConvertUToF", "half"},
+ ConvertCase{kBool, kF16, "OpSelect", "half"},
+
+ // To i32.
+ ConvertCase{kF32, kI32, "OpConvertFToS", "int"},
+ ConvertCase{kF16, kI32, "OpConvertFToS", "int"},
+ ConvertCase{kU32, kI32, "OpBitcast", "int"},
+ ConvertCase{kBool, kI32, "OpSelect", "int"},
+
+ // To u32.
+ ConvertCase{kF32, kU32, "OpConvertFToU", "uint"},
+ ConvertCase{kF16, kU32, "OpConvertFToU", "uint"},
+ ConvertCase{kI32, kU32, "OpBitcast", "uint"},
+ ConvertCase{kBool, kU32, "OpSelect", "uint"},
+
+ // To bool.
+ ConvertCase{kF32, kBool, "OpFUnordNotEqual", "bool"},
+ ConvertCase{kF16, kBool, "OpFUnordNotEqual", "bool"},
+ ConvertCase{kI32, kBool, "OpINotEqual", "bool"},
+ ConvertCase{kU32, kBool, "OpINotEqual", "bool"}),
+ PrintCase);
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 8450235..5251157 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -17,151 +17,215 @@
namespace tint::writer::spirv {
namespace {
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
TEST_F(SpvGeneratorImplTest, Function_Empty) {
auto* func = b.Function("foo", ty.void_());
- func->Block()->Append(b.Return(func));
+ b.With(func->Block(), [&] { //
+ b.Return(func);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %foo = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
- auto* func = b.Function("foo", ty.void_());
- func->Block()->Append(b.Return(func));
+ auto* func_a = b.Function("func_a", ty.void_());
+ b.With(func_a->Block(), [&] { //
+ b.Return(func_a);
+ });
+ auto* func_b = b.Function("func_b", ty.void_());
+ b.With(func_b->Block(), [&] { //
+ b.Return(func_b);
+ });
+ auto* func_c = b.Function("func_c", ty.void_());
+ b.With(func_c->Block(), [&] { //
+ b.Return(func_c);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ ; Types, variables and constants
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
- generator_.EmitFunction(func);
- generator_.EmitFunction(func);
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeVoid
-%3 = OpTypeFunction %2
+ ; Function func_a
+ %func_a = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; Function func_b
+ %func_b = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; Function func_c
+ %func_c = OpFunction %void None %3
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func =
b.Function("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- func->Block()->Append(b.Return(func));
+ b.With(func->Block(), [&] { //
+ b.Return(func);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 32 4 1
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint GLCompute %1 "main"
-OpExecutionMode %1 LocalSize 32 4 1
-OpName %1 "main"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
+ ; Debug Information
+ OpName %main "main" ; id %1
+
+ ; Types, variables and constants
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+
+ ; Function main
+ %main = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kFragment);
- func->Block()->Append(b.Return(func));
+ b.With(func->Block(), [&] { //
+ b.Return(func);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Fragment %1 "main"
-OpExecutionMode %1 OriginUpperLeft
-OpName %1 "main"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
+ ; Debug Information
+ OpName %main "main" ; id %1
+
+ ; Types, variables and constants
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+
+ ; Function main
+ %main = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kVertex);
- func->Block()->Append(b.Return(func));
+ b.With(func->Block(), [&] { //
+ b.Return(func);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpEntryPoint Vertex %main "main"
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Vertex %1 "main"
-OpName %1 "main"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
+ ; Debug Information
+ OpName %main "main" ; id %1
+
+ ; Types, variables and constants
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+
+ ; Function main
+ %main = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
auto* f1 = b.Function("main1", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- f1->Block()->Append(b.Return(f1));
+ b.With(f1->Block(), [&] { //
+ b.Return(f1);
+ });
auto* f2 = b.Function("main2", ty.void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
- f2->Block()->Append(b.Return(f2));
+ b.With(f2->Block(), [&] { //
+ b.Return(f2);
+ });
auto* f3 = b.Function("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
- f3->Block()->Append(b.Return(f3));
+ b.With(f3->Block(), [&] { //
+ b.Return(f3);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpEntryPoint GLCompute %main1 "main1"
+ OpEntryPoint GLCompute %main2 "main2"
+ OpEntryPoint Fragment %main3 "main3"
+ OpExecutionMode %main1 LocalSize 32 4 1
+ OpExecutionMode %main2 LocalSize 8 2 16
+ OpExecutionMode %main3 OriginUpperLeft
- generator_.EmitFunction(f1);
- generator_.EmitFunction(f2);
- generator_.EmitFunction(f3);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint GLCompute %1 "main1"
-OpEntryPoint GLCompute %5 "main2"
-OpEntryPoint Fragment %7 "main3"
-OpExecutionMode %1 LocalSize 32 4 1
-OpExecutionMode %5 LocalSize 8 2 16
-OpExecutionMode %7 OriginUpperLeft
-OpName %1 "main1"
-OpName %5 "main2"
-OpName %7 "main3"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
-%5 = OpFunction %2 None %3
-%6 = OpLabel
-OpReturn
-OpFunctionEnd
-%7 = OpFunction %2 None %3
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ; Debug Information
+ OpName %main1 "main1" ; id %1
+ OpName %main2 "main2" ; id %5
+ OpName %main3 "main3" ; id %7
+
+ ; Types, variables and constants
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+
+ ; Function main1
+ %main1 = OpFunction %void None %3
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; Function main2
+ %main2 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+ ; Function main3
+ %main3 = OpFunction %void None %3
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_ReturnValue) {
auto* func = b.Function("foo", ty.i32());
- func->Block()->Append(b.Return(func, i32(42)));
+ b.With(func->Block(), [&] { //
+ b.Return(func, 42_i);
+ });
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %3 = OpTypeFunction %int
+ %int_42 = OpConstant %int 42
+ %void = OpTypeVoid
+ %8 = OpTypeFunction %void
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%5 = OpConstant %2 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturnValue %5
-OpFunctionEnd
+ ; Function foo
+ %foo = OpFunction %int None %3
+ %4 = OpLabel
+ OpReturnValue %int_42
+ OpFunctionEnd
)");
}
@@ -177,97 +241,61 @@
b.Return(func, result);
});
- ASSERT_TRUE(IRIsValid()) << Error();
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %5 = OpTypeFunction %int %int %int
+ %void = OpTypeVoid
+ %10 = OpTypeFunction %void
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpName %3 "x"
-OpName %4 "y"
-%2 = OpTypeInt 32 1
-%5 = OpTypeFunction %2 %2 %2
-%1 = OpFunction %2 None %5
-%3 = OpFunctionParameter %2
-%4 = OpFunctionParameter %2
-%6 = OpLabel
-%7 = OpIAdd %2 %3 %4
-OpReturnValue %7
-OpFunctionEnd
+ ; Function foo
+ %foo = OpFunction %int None %5
+ %x = OpFunctionParameter %int
+ %y = OpFunctionParameter %int
+ %6 = OpLabel
+ %7 = OpIAdd %int %x %y
+ OpReturnValue %7
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_Call) {
- auto* i32_ty = ty.i32();
- auto* x = b.FunctionParam(i32_ty);
- auto* y = b.FunctionParam(i32_ty);
- auto* foo = b.Function("foo", i32_ty);
+ auto* i32 = ty.i32();
+ auto* x = b.FunctionParam("x", i32);
+ auto* y = b.FunctionParam("y", i32);
+ auto* foo = b.Function("foo", i32);
foo->SetParams({x, y});
b.With(foo->Block(), [&] {
- auto* result = b.Add(i32_ty, x, y);
+ auto* result = b.Add(i32, x, y);
b.Return(foo, result);
});
auto* bar = b.Function("bar", ty.void_());
b.With(bar->Block(), [&] {
- b.Call(i32_ty, foo, i32(2), i32(3));
+ auto* result = b.Call(i32, foo, 2_i, 3_i);
b.Return(bar);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(foo);
- generator_.EmitFunction(bar);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpName %8 "bar"
-%2 = OpTypeInt 32 1
-%5 = OpTypeFunction %2 %2 %2
-%9 = OpTypeVoid
-%10 = OpTypeFunction %9
-%13 = OpConstant %2 2
-%14 = OpConstant %2 3
-%1 = OpFunction %2 None %5
-%3 = OpFunctionParameter %2
-%4 = OpFunctionParameter %2
-%6 = OpLabel
-%7 = OpIAdd %2 %3 %4
-OpReturnValue %7
-OpFunctionEnd
-%8 = OpFunction %9 None %10
-%11 = OpLabel
-%12 = OpFunctionCall %2 %1 %13 %14
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpFunctionCall %int %foo %int_2 %int_3");
}
TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
auto* foo = b.Function("foo", ty.void_());
- foo->Block()->Append(b.Return(foo));
+ b.With(foo->Block(), [&] { //
+ b.Return(foo);
+ });
auto* bar = b.Function("bar", ty.void_());
b.With(bar->Block(), [&] {
- b.Call(ty.void_(), foo, utils::Empty);
+ auto* result = b.Call(ty.void_(), foo);
b.Return(bar);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(foo);
- generator_.EmitFunction(bar);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpName %5 "bar"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpReturn
-OpFunctionEnd
-%5 = OpFunction %2 None %3
-%6 = OpLabel
-%7 = OpFunctionCall %2 %1
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpFunctionCall %void %foo");
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index 892409d..b548745 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -21,316 +21,270 @@
TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ b.With(i->True(), [&] { //
+ b.ExitIf(i);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i);
+ });
+ b.Return(func);
+ });
- auto* i = b.If(true);
- i->True()->Append(b.ExitIf(i));
- i->False()->Append(b.ExitIf(i));
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeBool
-%6 = OpConstantTrue %7
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %6 %5 %5
-%5 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %5 %5
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
auto* func = b.Function("foo", ty.void_());
-
- auto* i = b.If(true);
- i->False()->Append(b.ExitIf(i));
-
- b.With(i->True(), [&] {
- b.Add(ty.i32(), 1_i, 1_i);
- b.ExitIf(i);
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ b.With(i->True(), [&] {
+ b.Add(ty.i32(), 1_i, 1_i);
+ b.ExitIf(i);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i);
+ });
+ b.Return(func);
});
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%8 = OpTypeBool
-%7 = OpConstantTrue %8
-%10 = OpTypeInt 32 1
-%11 = OpConstant %10 1
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %7 %6 %5
-%6 = OpLabel
-%9 = OpIAdd %10 %11 %11
-OpBranch %5
-%5 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %6 %5
+ %6 = OpLabel
+ %9 = OpIAdd %int %int_1 %int_1
+ OpBranch %5
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
auto* func = b.Function("foo", ty.void_());
-
- auto* i = b.If(true);
- i->True()->Append(b.ExitIf(i));
-
- b.With(i->False(), [&] {
- b.Add(ty.i32(), 1_i, 1_i);
- b.ExitIf(i);
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ b.With(i->True(), [&] { //
+ b.ExitIf(i);
+ });
+ b.With(i->False(), [&] {
+ b.Add(ty.i32(), 1_i, 1_i);
+ b.ExitIf(i);
+ });
+ b.Return(func);
});
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%8 = OpTypeBool
-%7 = OpConstantTrue %8
-%10 = OpTypeInt 32 1
-%11 = OpConstant %10 1
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %7 %5 %6
-%6 = OpLabel
-%9 = OpIAdd %10 %11 %11
-OpBranch %5
-%5 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %5 %6
+ %6 = OpLabel
+ %9 = OpIAdd %int %int_1 %int_1
+ OpBranch %5
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ b.With(i->True(), [&] { //
+ b.Return(func);
+ });
+ b.With(i->False(), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
- auto* i = b.If(true);
- i->True()->Append(b.Return(func));
- i->False()->Append(b.Return(func));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Unreachable());
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpReturn
-%7 = OpLabel
-OpReturn
-%5 = OpLabel
-OpUnreachable
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %5 %5
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()));
+ b.With(i->True(), [&] { //
+ b.ExitIf(i, 10_i);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i, 20_i);
+ });
+ b.Return(func, i);
+ });
- auto* i = b.If(true);
- i->SetResults(b.InstructionResult(ty.i32()));
- i->True()->Append(b.ExitIf(i, 10_i));
- i->False()->Append(b.ExitIf(i, 20_i));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func, i));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%11 = OpConstant %2 10
-%12 = OpConstant %2 20
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpBranch %5
-%7 = OpLabel
-OpBranch %5
-%5 = OpLabel
-%10 = OpPhi %2 %11 %6 %12 %7
-OpReturnValue %10
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %6 %7
+ %6 = OpLabel
+ OpBranch %5
+ %7 = OpLabel
+ OpBranch %5
+ %5 = OpLabel
+ %10 = OpPhi %int %int_10 %6 %int_20 %7
+ OpReturnValue %10
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_TrueReturn) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()));
+ b.With(i->True(), [&] { //
+ b.Return(func, 42_i);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i, 20_i);
+ });
+ b.Return(func, i);
+ });
- auto* i = b.If(true);
- i->SetResults(b.InstructionResult(ty.i32()));
- i->True()->Append(b.Return(func, 42_i));
- i->False()->Append(b.ExitIf(i, 20_i));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func, i));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%10 = OpConstant %2 42
-%12 = OpConstant %2 20
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpReturnValue %10
-%7 = OpLabel
-OpBranch %5
-%5 = OpLabel
-%11 = OpPhi %2 %12 %7
-OpReturnValue %11
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%17 = OpUndef %int");
+ EXPECT_INST(R"(
+ OpSelectionMerge %11 None
+ OpBranchConditional %true %12 %13
+ %12 = OpLabel
+ OpStore %continue_execution %false
+ OpStore %return_value %int_42
+ OpBranch %11
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %16 = OpPhi %int %17 %12 %int_20 %13
+ %19 = OpLoad %bool %continue_execution
+ OpSelectionMerge %20 None
+ OpBranchConditional %19 %21 %20
+ %21 = OpLabel
+ OpStore %return_value %16
+ OpBranch %20
+ %20 = OpLabel
+ %22 = OpLoad %int %return_value
+ OpReturnValue %22
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_FalseReturn) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()));
+ b.With(i->True(), [&] { //
+ b.ExitIf(i, 10_i);
+ });
+ b.With(i->False(), [&] { //
+ b.Return(func, 42_i);
+ });
+ b.Return(func, i);
+ });
- auto* i = b.If(true);
- i->SetResults(b.InstructionResult(ty.i32()));
- i->True()->Append(b.ExitIf(i, 10_i));
- i->False()->Append(b.Return(func, 42_i));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func, i));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%10 = OpConstant %2 42
-%12 = OpConstant %2 10
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpBranch %5
-%7 = OpLabel
-OpReturnValue %10
-%5 = OpLabel
-%11 = OpPhi %2 %12 %6
-OpReturnValue %11
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%18 = OpUndef %int");
+ EXPECT_INST(R"(
+ OpSelectionMerge %11 None
+ OpBranchConditional %true %12 %13
+ %12 = OpLabel
+ OpBranch %11
+ %13 = OpLabel
+ OpStore %continue_execution %false
+ OpStore %return_value %int_42
+ OpBranch %11
+ %11 = OpLabel
+ %16 = OpPhi %int %int_10 %12 %18 %13
+ %19 = OpLoad %bool %continue_execution
+ OpSelectionMerge %20 None
+ OpBranchConditional %19 %21 %20
+ %21 = OpLabel
+ OpStore %return_value %16
+ OpBranch %20
+ %20 = OpLabel
+ %22 = OpLoad %int %return_value
+ OpReturnValue %22
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_Phi_MultipleValue_0) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ b.With(i->True(), [&] { //
+ b.ExitIf(i, 10_i, true);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i, 20_i, false);
+ });
+ b.Return(func, i->Result(0));
+ });
- auto* i = b.If(true);
- i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
- i->True()->Append(b.ExitIf(i, 10_i, true));
- i->False()->Append(b.ExitIf(i, 20_i, false));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func, i->Result(0)));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%11 = OpConstant %2 10
-%12 = OpConstant %2 20
-%14 = OpConstantFalse %9
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpBranch %5
-%7 = OpLabel
-OpBranch %5
-%5 = OpLabel
-%10 = OpPhi %2 %11 %6 %12 %7
-%13 = OpPhi %9 %8 %6 %14 %7
-OpReturnValue %10
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %6 %7
+ %6 = OpLabel
+ OpBranch %5
+ %7 = OpLabel
+ OpBranch %5
+ %5 = OpLabel
+ %10 = OpPhi %int %int_10 %6 %int_20 %7
+ %13 = OpPhi %bool %true %6 %false %7
+ OpReturnValue %10
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, If_Phi_MultipleValue_1) {
auto* func = b.Function("foo", ty.bool_());
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ b.With(i->True(), [&] { //
+ b.ExitIf(i, 10_i, true);
+ });
+ b.With(i->False(), [&] { //
+ b.ExitIf(i, 20_i, false);
+ });
+ b.Return(func, i->Result(1));
+ });
- auto* i = b.If(true);
- i->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
- i->True()->Append(b.ExitIf(i, 10_i, true));
- i->False()->Append(b.ExitIf(i, 20_i, false));
-
- func->Block()->Append(i);
- func->Block()->Append(b.Return(func, i->Result(1)));
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeBool
-%3 = OpTypeFunction %2
-%8 = OpConstantTrue %2
-%9 = OpTypeInt 32 1
-%11 = OpConstant %9 10
-%12 = OpConstant %9 20
-%14 = OpConstantFalse %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpBranch %5
-%7 = OpLabel
-OpBranch %5
-%5 = OpLabel
-%10 = OpPhi %9 %11 %6 %12 %7
-%13 = OpPhi %2 %8 %6 %14 %7
-OpReturnValue %13
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %6 %7
+ %6 = OpLabel
+ OpBranch %5
+ %7 = OpLabel
+ OpBranch %5
+ %5 = OpLabel
+ %10 = OpPhi %int %int_10 %6 %int_20 %7
+ %13 = OpPhi %bool %true %6 %false %7
+ OpReturnValue %13
+ OpFunctionEnd
)");
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index d0d6d24..67efd6a 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -21,159 +21,146 @@
TEST_F(SpvGeneratorImplTest, Loop_BreakIf) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { //
+ b.Continue(loop);
- auto* loop = b.Loop();
+ b.With(loop->Continuing(), [&] { //
+ b.BreakIf(loop, true);
+ });
+ });
+ b.Return(func);
+ });
- loop->Body()->Append(b.Continue(loop));
- loop->Continuing()->Append(b.BreakIf(loop, true));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%10 = OpTypeBool
-%9 = OpConstantTrue %10
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpBranch %7
-%7 = OpLabel
-OpBranchConditional %9 %8 %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %6
+ %6 = OpLabel
+ OpBranchConditional %true %8 %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
// Test that we still emit the continuing block with a back-edge, even when it is unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalBreakInBody) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { //
+ b.ExitLoop(loop);
+ });
+ b.Return(func);
+ });
- auto* loop = b.Loop();
-
- loop->Body()->Append(b.ExitLoop(loop));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpBranch %8
-%7 = OpLabel
-OpBranch %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %8
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalBreakInBody) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* cond_break = b.If(true);
+ b.With(cond_break->True(), [&] { //
+ b.ExitLoop(loop);
+ });
+ b.With(cond_break->False(), [&] { //
+ b.ExitIf(cond_break);
+ });
+ b.Continue(loop);
- auto* loop = b.Loop();
+ b.With(loop->Continuing(), [&] { //
+ b.NextIteration(loop);
+ });
+ });
+ b.Return(func);
+ });
- auto* cond_break = b.If(true);
- cond_break->True()->Append(b.ExitLoop(loop));
- cond_break->False()->Append(b.ExitIf(cond_break));
-
- loop->Body()->Append(cond_break);
- loop->Body()->Append(b.Continue(loop));
- loop->Continuing()->Append(b.NextIteration(loop));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%12 = OpTypeBool
-%11 = OpConstantTrue %12
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpSelectionMerge %9 None
-OpBranchConditional %11 %10 %9
-%10 = OpLabel
-OpBranch %8
-%9 = OpLabel
-OpBranch %7
-%7 = OpLabel
-OpBranch %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpSelectionMerge %9 None
+ OpBranchConditional %true %10 %9
+ %10 = OpLabel
+ OpBranch %8
+ %9 = OpLabel
+ OpBranch %6
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalContinueInBody) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* cond_break = b.If(true);
+ b.With(cond_break->True(), [&] { //
+ b.Continue(loop);
+ });
+ b.With(cond_break->False(), [&] { //
+ b.ExitIf(cond_break);
+ });
+ b.ExitLoop(loop);
- auto* loop = b.Loop();
+ b.With(loop->Continuing(), [&] { //
+ b.NextIteration(loop);
+ });
+ });
+ b.Return(func);
+ });
- auto* cond_break = b.If(true);
- cond_break->True()->Append(b.Continue(loop));
- cond_break->False()->Append(b.ExitIf(cond_break));
-
- loop->Body()->Append(cond_break);
- loop->Body()->Append(b.ExitLoop(loop));
- loop->Continuing()->Append(b.NextIteration(loop));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%12 = OpTypeBool
-%11 = OpConstantTrue %12
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpSelectionMerge %9 None
-OpBranchConditional %11 %10 %9
-%10 = OpLabel
-OpBranch %7
-%9 = OpLabel
-OpBranch %8
-%7 = OpLabel
-OpBranch %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpSelectionMerge %9 None
+ OpBranchConditional %true %10 %9
+ %10 = OpLabel
+ OpBranch %6
+ %9 = OpLabel
+ OpBranch %8
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
@@ -181,169 +168,158 @@
// they are unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalReturnInBody) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
- auto* loop = b.Loop();
- loop->Body()->Append(b.Return(func));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Unreachable());
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpReturn
-%7 = OpLabel
-OpBranch %5
-%8 = OpLabel
-OpUnreachable
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %8
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Loop_UseResultFromBodyInContinuing) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* result = b.Equal(ty.bool_(), 1_i, 2_i);
+ b.Continue(loop, result);
- auto* loop = b.Loop();
+ b.With(loop->Continuing(), [&] { //
+ b.BreakIf(loop, result);
+ });
+ });
+ b.Return(func);
+ });
- auto* result = loop->Body()->Append(b.Equal(ty.i32(), 1_i, 2_i));
- loop->Body()->Append(b.Continue(loop, result));
-
- loop->Continuing()->Append(b.BreakIf(loop, result));
-
- func->Block()->Append(loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%10 = OpTypeInt 32 1
-%11 = OpConstant %10 1
-%12 = OpConstant %10 2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-%9 = OpIEqual %10 %11 %12
-OpBranch %7
-%7 = OpLabel
-OpBranchConditional %9 %8 %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ %9 = OpIEqual %bool %int_1 %int_2
+ OpBranch %6
+ %6 = OpLabel
+ OpBranchConditional %9 %8 %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInBody) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* outer_loop = b.Loop();
+ b.With(outer_loop->Body(), [&] {
+ auto* inner_loop = b.Loop();
+ b.With(inner_loop->Body(), [&] {
+ b.ExitLoop(inner_loop);
- auto* outer_loop = b.Loop();
- auto* inner_loop = b.Loop();
+ b.With(inner_loop->Continuing(), [&] { //
+ b.NextIteration(inner_loop);
+ });
+ });
+ b.Continue(outer_loop);
- inner_loop->Body()->Append(b.ExitLoop(inner_loop));
- inner_loop->Continuing()->Append(b.NextIteration(inner_loop));
+ b.With(outer_loop->Continuing(),
+ [&] { //
+ b.BreakIf(outer_loop, true);
+ });
+ });
+ b.Return(func);
+ });
- outer_loop->Body()->Append(inner_loop);
- outer_loop->Body()->Append(b.Continue(outer_loop));
- outer_loop->Continuing()->Append(b.BreakIf(outer_loop, true));
-
- func->Block()->Append(outer_loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%14 = OpTypeBool
-%13 = OpConstantTrue %14
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpBranch %9
-%9 = OpLabel
-OpLoopMerge %12 %11 None
-OpBranch %10
-%10 = OpLabel
-OpBranch %12
-%11 = OpLabel
-OpBranch %9
-%12 = OpLabel
-OpBranch %7
-%7 = OpLabel
-OpBranchConditional %13 %8 %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ OpLoopMerge %12 %10 None
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %12
+ %10 = OpLabel
+ OpBranch %11
+ %12 = OpLabel
+ OpBranch %6
+ %6 = OpLabel
+ OpBranchConditional %true %8 %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInContinuing) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* outer_loop = b.Loop();
+ b.With(outer_loop->Body(), [&] {
+ b.Continue(outer_loop);
- auto* outer_loop = b.Loop();
- auto* inner_loop = b.Loop();
+ b.With(outer_loop->Continuing(), [&] {
+ auto* inner_loop = b.Loop();
+ b.With(inner_loop->Body(), [&] {
+ b.Continue(inner_loop);
- inner_loop->Body()->Append(b.Continue(inner_loop));
- inner_loop->Continuing()->Append(b.BreakIf(inner_loop, true));
+ b.With(inner_loop->Continuing(), [&] { //
+ b.BreakIf(inner_loop, true);
+ });
+ });
+ b.BreakIf(outer_loop, true);
+ });
+ });
+ b.Return(func);
+ });
- outer_loop->Body()->Append(b.Continue(outer_loop));
- outer_loop->Continuing()->Append(inner_loop);
- outer_loop->Continuing()->Append(b.BreakIf(outer_loop, true));
-
- func->Block()->Append(outer_loop);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%14 = OpTypeBool
-%13 = OpConstantTrue %14
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpLoopMerge %8 %7 None
-OpBranch %6
-%6 = OpLabel
-OpBranch %7
-%7 = OpLabel
-OpBranch %9
-%9 = OpLabel
-OpLoopMerge %12 %11 None
-OpBranch %10
-%10 = OpLabel
-OpBranch %11
-%11 = OpLabel
-OpBranchConditional %13 %12 %9
-%12 = OpLabel
-OpBranchConditional %13 %8 %5
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpBranch %6
+ %6 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ OpLoopMerge %12 %10 None
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpBranchConditional %true %12 %11
+ %12 = OpLabel
+ OpBranchConditional %true %8 %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
@@ -353,7 +329,9 @@
b.With(func->Block(), [&] {
auto* loop = b.Loop();
- b.With(loop->Initializer(), [&] { b.NextIteration(loop, 1_i, false); });
+ b.With(loop->Initializer(), [&] { //
+ b.NextIteration(loop, 1_i, false);
+ });
auto* loop_param = b.BlockParam(ty.i32());
loop->Body()->SetParams({loop_param});
@@ -373,35 +351,24 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%10 = OpTypeInt 32 1
-%12 = OpConstant %10 1
-%16 = OpTypeBool
-%17 = OpConstant %10 5
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpBranch %6
-%6 = OpLabel
-%11 = OpPhi %10 %12 %5 %13 %8
-OpLoopMerge %9 %8 None
-OpBranch %7
-%7 = OpLabel
-%14 = OpIAdd %10 %11 %12
-OpBranch %8
-%8 = OpLabel
-%13 = OpPhi %10 %14 %6
-%15 = OpSGreaterThan %16 %13 %17
-OpBranchConditional %15 %9 %6
-%9 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %5 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %int %int_1 %5 %13 %7
+ OpLoopMerge %9 %7 None
+ OpBranch %6
+ %6 = OpLabel
+ %14 = OpIAdd %int %11 %int_1
+ OpBranch %7
+ %7 = OpLabel
+ %13 = OpPhi %int %14 %6
+ %15 = OpSGreaterThan %bool %13 %int_5
+ OpBranchConditional %15 %9 %8
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
@@ -411,7 +378,9 @@
b.With(func->Block(), [&] {
auto* loop = b.Loop();
- b.With(loop->Initializer(), [&] { b.NextIteration(loop, 1_i, false); });
+ b.With(loop->Initializer(), [&] { //
+ b.NextIteration(loop, 1_i, false);
+ });
auto* loop_param_a = b.BlockParam(ty.i32());
auto* loop_param_b = b.BlockParam(ty.bool_());
@@ -434,39 +403,27 @@
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%10 = OpTypeInt 32 1
-%12 = OpConstant %10 1
-%14 = OpTypeBool
-%16 = OpConstantFalse %14
-%21 = OpConstant %10 5
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpBranch %5
-%5 = OpLabel
-OpBranch %6
-%6 = OpLabel
-%11 = OpPhi %10 %12 %5 %13 %8
-%15 = OpPhi %14 %16 %5 %17 %8
-OpLoopMerge %9 %8 None
-OpBranch %7
-%7 = OpLabel
-%18 = OpIAdd %10 %11 %12
-OpBranch %8
-%8 = OpLabel
-%13 = OpPhi %10 %18 %6
-%19 = OpPhi %14 %15 %6
-%20 = OpSGreaterThan %14 %13 %21
-%17 = OpLogicalEqual %14 %19 %16
-OpBranchConditional %20 %9 %6
-%9 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %5 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %int %int_1 %5 %13 %7
+ %15 = OpPhi %bool %false %5 %17 %7
+ OpLoopMerge %9 %7 None
+ OpBranch %6
+ %6 = OpLabel
+ %18 = OpIAdd %int %11 %int_1
+ OpBranch %7
+ %7 = OpLabel
+ %13 = OpPhi %int %18 %6
+ %19 = OpPhi %bool %15 %6
+ %20 = OpSGreaterThan %bool %13 %int_5
+ %17 = OpLogicalEqual %bool %19 %false
+ OpBranchConditional %20 %9 %8
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index 396e0fe..d851066 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -21,368 +21,351 @@
TEST_F(SpvGeneratorImplTest, Switch_Basic) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(42_i);
- auto* swtch = b.Switch(42_i);
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ b.With(def_case, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
- def_case->Append(b.ExitSwitch(swtch));
+ b.Return(func);
+ });
- func->Block()->Append(swtch);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %8 None
-OpSwitch %6 %5
-%5 = OpLabel
-OpBranch %8
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %8 None
+ OpSwitch %int_42 %5
+ %5 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleCases) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(42_i);
- auto* swtch = b.Switch(42_i);
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ b.With(case_a, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
- case_a->Append(b.ExitSwitch(swtch));
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(swtch));
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ b.With(def_case, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
- def_case->Append(b.ExitSwitch(swtch));
+ b.Return(func);
+ });
- func->Block()->Append(swtch);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %10 None
-OpSwitch %6 %5 1 %8 2 %9
-%8 = OpLabel
-OpBranch %10
-%9 = OpLabel
-OpBranch %10
-%5 = OpLabel
-OpBranch %10
-%10 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %10 None
+ OpSwitch %int_42 %5 1 %8 2 %9
+ %8 = OpLabel
+ OpBranch %10
+ %9 = OpLabel
+ OpBranch %10
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleSelectorsPerCase) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(42_i);
- auto* swtch = b.Switch(42_i);
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{b.Constant(3_i)}});
+ b.With(case_a, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{b.Constant(3_i)}});
- case_a->Append(b.ExitSwitch(swtch));
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
+ ir::Switch::CaseSelector{b.Constant(4_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
- ir::Switch::CaseSelector{b.Constant(4_i)}});
- case_b->Append(b.ExitSwitch(swtch));
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
+ ir::Switch::CaseSelector()});
+ b.With(def_case, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
- ir::Switch::CaseSelector()});
- def_case->Append(b.ExitSwitch(swtch));
+ b.Return(func);
+ });
- func->Block()->Append(swtch);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %10 None
-OpSwitch %6 %5 1 %8 3 %8 2 %9 4 %9 5 %5
-%8 = OpLabel
-OpBranch %10
-%9 = OpLabel
-OpBranch %10
-%5 = OpLabel
-OpBranch %10
-%10 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %10 None
+ OpSwitch %int_42 %5 1 %8 3 %8 2 %9 4 %9 5 %5
+ %8 = OpLabel
+ OpBranch %10
+ %9 = OpLabel
+ OpBranch %10
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_AllCasesReturn) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(42_i);
- auto* swtch = b.Switch(42_i);
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ b.With(case_a, [&] { //
+ b.Return(func);
+ });
- auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
- case_a->Append(b.Return(func));
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.Return(func);
+ });
- auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.Return(func));
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ b.With(def_case, [&] { //
+ b.Return(func);
+ });
- auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
- def_case->Append(b.Return(func));
+ b.Unreachable();
+ });
- func->Block()->Append(swtch);
- func->Block()->Append(b.Unreachable());
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %10 None
-OpSwitch %6 %5 1 %8 2 %9
-%8 = OpLabel
-OpReturn
-%9 = OpLabel
-OpReturn
-%5 = OpLabel
-OpReturn
-%10 = OpLabel
-OpUnreachable
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %10 None
+ OpSwitch %int_42 %5 1 %8 2 %9
+ %8 = OpLabel
+ OpBranch %10
+ %9 = OpLabel
+ OpBranch %10
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_ConditionalBreak) {
auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(42_i);
- auto* swtch = b.Switch(42_i);
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ b.With(case_a, [&] {
+ auto* cond_break = b.If(true);
+ b.With(cond_break->True(), [&] { //
+ b.ExitSwitch(swtch);
+ });
+ b.With(cond_break->False(), [&] { //
+ b.ExitIf(cond_break);
+ });
- auto* cond_break = b.If(true);
- cond_break->True()->Append(b.ExitSwitch(swtch));
- cond_break->False()->Append(b.ExitIf(cond_break));
+ b.Return(func);
+ });
- auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
- case_a->Append(cond_break);
- case_a->Append(b.Return(func));
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ b.With(def_case, [&] { //
+ b.ExitSwitch(swtch);
+ });
- auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
- def_case->Append(b.ExitSwitch(swtch));
+ b.Return(func);
+ });
- func->Block()->Append(swtch);
- func->Block()->Append(b.Return(func));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%13 = OpTypeBool
-%12 = OpConstantTrue %13
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %9 None
-OpSwitch %6 %5 1 %8
-%8 = OpLabel
-OpSelectionMerge %10 None
-OpBranchConditional %12 %11 %10
-%11 = OpLabel
-OpBranch %9
-%10 = OpLabel
-OpReturn
-%5 = OpLabel
-OpBranch %9
-%9 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %9 None
+ OpSwitch %int_42 %5 1 %8
+ %8 = OpLabel
+ OpSelectionMerge %10 None
+ OpBranchConditional %true %11 %10
+ %11 = OpLabel
+ OpBranch %9
+ %10 = OpLabel
+ OpBranch %9
+ %5 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* s = b.Switch(42_i);
+ s->SetResults(b.InstructionResult(ty.i32()));
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ b.With(case_a, [&] { //
+ b.ExitSwitch(s, 10_i);
+ });
- auto* s = b.Switch(42_i);
- s->SetResults(b.InstructionResult(ty.i32()));
- auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.ExitSwitch(s, 10_i));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(s, 20_i);
+ });
- auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, 20_i));
+ b.Return(func, s);
+ });
- func->Block()->Append(s);
- func->Block()->Append(b.Return(func, s));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%6 = OpConstant %2 42
-%10 = OpConstant %2 10
-%11 = OpConstant %2 20
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %8 None
-OpSwitch %6 %5 1 %5 2 %7
-%5 = OpLabel
-OpBranch %8
-%7 = OpLabel
-OpBranch %8
-%8 = OpLabel
-%9 = OpPhi %2 %10 %5 %11 %7
-OpReturnValue %9
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %8 None
+ OpSwitch %int_42 %5 1 %5 2 %7
+ %5 = OpLabel
+ OpBranch %8
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %9 = OpPhi %int %int_10 %5 %int_20 %7
+ OpReturnValue %9
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue_CaseReturn) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* s = b.Switch(42_i);
+ s->SetResults(b.InstructionResult(ty.i32()));
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ b.With(case_a, [&] { //
+ b.Return(func, 10_i);
+ });
- auto* s = b.Switch(42_i);
- s->SetResults(b.InstructionResult(ty.i32()));
- auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.Return(func, 10_i));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(s, 20_i);
+ });
- auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, 20_i));
+ b.Return(func, s);
+ });
- func->Block()->Append(s);
- func->Block()->Append(b.Return(func, s));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%6 = OpConstant %2 42
-%9 = OpConstant %2 10
-%11 = OpConstant %2 20
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %8 None
-OpSwitch %6 %5 1 %5 2 %7
-%5 = OpLabel
-OpReturnValue %9
-%7 = OpLabel
-OpBranch %8
-%8 = OpLabel
-%10 = OpPhi %2 %11 %7
-OpReturnValue %10
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+%return_value = OpVariable %_ptr_Function_int Function
+%continue_execution = OpVariable %_ptr_Function_bool Function
+ OpStore %continue_execution %true
+ OpSelectionMerge %14 None
+ OpSwitch %int_42 %11 1 %11 2 %13
+ %11 = OpLabel
+ OpStore %continue_execution %false
+ OpStore %return_value %int_10
+ OpBranch %14
+ %13 = OpLabel
+ OpBranch %14
+ %14 = OpLabel
+ %17 = OpPhi %int %18 %11 %int_20 %13
+ %20 = OpLoad %bool %continue_execution
+ OpSelectionMerge %21 None
+ OpBranchConditional %20 %22 %21
+ %22 = OpLabel
+ OpStore %return_value %17
+ OpBranch %21
+ %21 = OpLabel
+ %23 = OpLoad %int %return_value
+ OpReturnValue %23
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_MultipleValue_0) {
auto* func = b.Function("foo", ty.i32());
+ b.With(func->Block(), [&] {
+ auto* s = b.Switch(42_i);
+ s->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ b.With(case_a, [&] { //
+ b.ExitSwitch(s, 10_i, true);
+ });
- auto* s = b.Switch(42_i);
- s->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
- auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.ExitSwitch(s, 10_i, true));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(s, 20_i, false);
+ });
- auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, 20_i, false));
+ b.Return(func, s->Result(0));
+ });
- func->Block()->Append(s);
- func->Block()->Append(b.Return(func, s->Result(0)));
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%6 = OpConstant %2 42
-%10 = OpConstant %2 10
-%11 = OpConstant %2 20
-%12 = OpTypeBool
-%14 = OpConstantTrue %12
-%15 = OpConstantFalse %12
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %8 None
-OpSwitch %6 %5 1 %5 2 %7
-%5 = OpLabel
-OpBranch %8
-%7 = OpLabel
-OpBranch %8
-%8 = OpLabel
-%9 = OpPhi %2 %10 %5 %11 %7
-%13 = OpPhi %12 %14 %5 %15 %7
-OpReturnValue %9
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %8 None
+ OpSwitch %int_42 %5 1 %5 2 %7
+ %5 = OpLabel
+ OpBranch %8
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %9 = OpPhi %int %int_10 %5 %int_20 %7
+ %13 = OpPhi %bool %true %5 %false %7
+ OpReturnValue %9
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_MultipleValue_1) {
auto* func = b.Function("foo", ty.bool_());
+ b.With(func->Block(), [&] {
+ auto* s = b.Switch(b.Constant(42_i));
+ s->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ b.With(case_a, [&] { //
+ b.ExitSwitch(s, 10_i, true);
+ });
- auto* s = b.Switch(b.Constant(42_i));
- s->SetResults(b.InstructionResult(ty.i32()), b.InstructionResult(ty.bool_()));
- auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.ExitSwitch(s, 10_i, true));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ b.With(case_b, [&] { //
+ b.ExitSwitch(s, 20_i, false);
+ });
- auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, 20_i, false));
+ b.Return(func, s->Result(1));
+ });
- func->Block()->Append(s);
- func->Block()->Append(b.Return(func, s->Result(1)));
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeBool
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpConstant %7 42
-%11 = OpConstant %7 10
-%12 = OpConstant %7 20
-%14 = OpConstantTrue %2
-%15 = OpConstantFalse %2
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %9 None
-OpSwitch %6 %5 1 %5 2 %8
-%5 = OpLabel
-OpBranch %9
-%8 = OpLabel
-OpBranch %9
-%9 = OpLabel
-%10 = OpPhi %7 %11 %5 %12 %8
-%13 = OpPhi %2 %14 %5 %15 %8
-OpReturnValue %13
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %4 = OpLabel
+ OpSelectionMerge %9 None
+ OpSwitch %int_42 %5 1 %5 2 %8
+ %5 = OpLabel
+ OpBranch %9
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ %10 = OpPhi %int %int_10 %5 %int_20 %8
+ %13 = OpPhi %bool %true %5 %false %8
+ OpReturnValue %13
+ OpFunctionEnd
)");
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_swizzle_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_swizzle_test.cc
new file mode 100644
index 0000000..491f0da
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_swizzle_test.cc
@@ -0,0 +1,80 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+namespace tint::writer::spirv {
+namespace {
+
+using namespace tint::builtin::fluent_types; // NOLINT
+using namespace tint::number_suffixes; // NOLINT
+
+TEST_F(SpvGeneratorImplTest, Swizzle_TwoElements) {
+ auto* vec = b.FunctionParam("vec", ty.vec4<i32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec});
+ b.With(func->Block(), [&] {
+ auto* result = b.Swizzle(ty.vec2<i32>(), vec, {3_u, 2_u});
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorShuffle %v2int %vec %vec 3 2");
+}
+
+TEST_F(SpvGeneratorImplTest, Swizzle_ThreeElements) {
+ auto* vec = b.FunctionParam("vec", ty.vec4<i32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec});
+ b.With(func->Block(), [&] {
+ auto* result = b.Swizzle(ty.vec3<i32>(), vec, {3_u, 2_u, 1_u});
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorShuffle %v3int %vec %vec 3 2 1");
+}
+
+TEST_F(SpvGeneratorImplTest, Swizzle_FourElements) {
+ auto* vec = b.FunctionParam("vec", ty.vec4<i32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec});
+ b.With(func->Block(), [&] {
+ auto* result = b.Swizzle(ty.vec4<i32>(), vec, {3_u, 2_u, 1_u, 0u});
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorShuffle %v4int %vec %vec 3 2 1 0");
+}
+
+TEST_F(SpvGeneratorImplTest, Swizzle_RepeatedElements) {
+ auto* vec = b.FunctionParam("vec", ty.vec2<i32>());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec});
+ b.With(func->Block(), [&] {
+ auto* result = b.Swizzle(ty.vec4<i32>(), vec, {1_u, 3_u, 1_u, 3_u});
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = OpVectorShuffle %v4int %vec %vec 1 3 1 3");
+}
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
index 9ec2a47..08d15bd 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
@@ -30,36 +30,48 @@
}
TEST_F(SpvGeneratorImplTest, Unreachable) {
- auto* func = b.Function("foo", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* ifelse = b.If(true);
+ b.With(ifelse->True(), [&] { //
+ b.Continue(loop);
+ });
+ b.With(ifelse->False(), [&] { //
+ b.Continue(loop);
+ });
+ b.Unreachable();
- auto* i = b.If(true);
- i->True()->Append(b.Return(func, 10_i));
- i->False()->Append(b.Return(func, 20_i));
+ b.With(loop->Continuing(), [&] { //
+ b.NextIteration(loop);
+ });
+ });
+ b.Return(func);
+ });
- func->Block()->Append(i);
- func->Block()->Append(b.Unreachable());
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeInt 32 1
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%10 = OpConstant %2 10
-%11 = OpConstant %2 20
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpReturnValue %10
-%7 = OpLabel
-OpReturnValue %11
-%5 = OpLabel
-OpUnreachable
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %foo = OpFunction %void None %3
+ %4 = OpLabel
+ OpBranch %7
+ %7 = OpLabel
+ OpLoopMerge %8 %6 None
+ OpBranch %5
+ %5 = OpLabel
+ OpSelectionMerge %9 None
+ OpBranchConditional %true %10 %11
+ %10 = OpLabel
+ OpBranch %6
+ %11 = OpLabel
+ OpBranch %6
+ %9 = OpLabel
+ OpUnreachable
+ %6 = OpLabel
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
index 7cd64e4..4284884 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc
@@ -13,9 +13,14 @@
// limitations under the License.
#include "src/tint/type/bool.h"
+#include "src/tint/type/depth_multisampled_texture.h"
+#include "src/tint/type/depth_texture.h"
#include "src/tint/type/f16.h"
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
+#include "src/tint/type/multisampled_texture.h"
+#include "src/tint/type/sampled_texture.h"
+#include "src/tint/type/storage_texture.h"
#include "src/tint/type/type.h"
#include "src/tint/type/u32.h"
#include "src/tint/type/void.h"
@@ -25,164 +30,140 @@
namespace {
TEST_F(SpvGeneratorImplTest, Type_Void) {
- auto id = generator_.Type(ty.void_());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeVoid\n");
+ generator_.Type(ty.void_());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%void = OpTypeVoid");
}
TEST_F(SpvGeneratorImplTest, Type_Bool) {
- auto id = generator_.Type(ty.bool_());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeBool\n");
+ generator_.Type(ty.bool_());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%bool = OpTypeBool");
}
TEST_F(SpvGeneratorImplTest, Type_I32) {
- auto id = generator_.Type(ty.i32());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 1\n");
+ generator_.Type(ty.i32());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%int = OpTypeInt 32 1");
}
TEST_F(SpvGeneratorImplTest, Type_U32) {
- auto id = generator_.Type(ty.u32());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 0\n");
+ generator_.Type(ty.u32());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%uint = OpTypeInt 32 0");
}
TEST_F(SpvGeneratorImplTest, Type_F32) {
- auto id = generator_.Type(ty.f32());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 32\n");
+ generator_.Type(ty.f32());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%float = OpTypeFloat 32");
}
TEST_F(SpvGeneratorImplTest, Type_F16) {
- auto id = generator_.Type(ty.f16());
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 16\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Capabilities()),
- "OpCapability Float16\n"
- "OpCapability UniformAndStorageBuffer16BitAccess\n"
- "OpCapability StorageBuffer16BitAccess\n"
- "OpCapability StorageInputOutput16\n");
+ generator_.Type(ty.f16());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpCapability Float16");
+ EXPECT_INST("OpCapability UniformAndStorageBuffer16BitAccess");
+ EXPECT_INST("OpCapability StorageBuffer16BitAccess");
+ EXPECT_INST("OpCapability StorageInputOutput16");
+ EXPECT_INST("%half = OpTypeFloat 16");
}
TEST_F(SpvGeneratorImplTest, Type_Vec2i) {
- auto id = generator_.Type(ty.vec2(ty.i32()));
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeInt 32 1\n"
- "%1 = OpTypeVector %2 2\n");
+ generator_.Type(ty.vec2<i32>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v2int = OpTypeVector %int 2");
}
TEST_F(SpvGeneratorImplTest, Type_Vec3u) {
- auto id = generator_.Type(ty.vec3(ty.u32()));
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeInt 32 0\n"
- "%1 = OpTypeVector %2 3\n");
+ generator_.Type(ty.vec3<u32>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v3uint = OpTypeVector %uint 3");
}
TEST_F(SpvGeneratorImplTest, Type_Vec4f) {
- auto id = generator_.Type(ty.vec4(ty.f32()));
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 32\n"
- "%1 = OpTypeVector %2 4\n");
+ generator_.Type(ty.vec4<f32>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v4float = OpTypeVector %float 4");
}
TEST_F(SpvGeneratorImplTest, Type_Vec2h) {
- auto id = generator_.Type(ty.vec2(ty.f16()));
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 16\n"
- "%1 = OpTypeVector %2 2\n");
+ generator_.Type(ty.vec2<f16>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v2half = OpTypeVector %half 2");
}
TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) {
- auto id = generator_.Type(ty.vec4(ty.bool_()));
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeBool\n"
- "%1 = OpTypeVector %2 4\n");
+ generator_.Type(ty.vec4<bool>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v4bool = OpTypeVector %bool 4");
}
TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) {
- auto* vec = ty.mat2x3(ty.f32());
- auto id = generator_.Type(vec);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%3 = OpTypeFloat 32\n"
- "%2 = OpTypeVector %3 3\n"
- "%1 = OpTypeMatrix %2 2\n");
+ generator_.Type(ty.mat2x3(ty.f32()));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%mat2v3float = OpTypeMatrix %v3float 2");
}
TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) {
- auto* vec = ty.mat4x2(ty.f16());
- auto id = generator_.Type(vec);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%3 = OpTypeFloat 16\n"
- "%2 = OpTypeVector %3 2\n"
- "%1 = OpTypeMatrix %2 4\n");
+ generator_.Type(ty.mat4x2(ty.f16()));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%mat4v2half = OpTypeMatrix %v2half 4");
}
TEST_F(SpvGeneratorImplTest, Type_Array_DefaultStride) {
- auto* arr = ty.array(ty.f32(), 4u);
- auto id = generator_.Type(arr);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 32\n"
- "%4 = OpTypeInt 32 0\n"
- "%3 = OpConstant %4 4\n"
- "%1 = OpTypeArray %2 %3\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 4\n");
+ generator_.Type(ty.array<f32, 4>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpDecorate %_arr_float_uint_4 ArrayStride 4");
+ EXPECT_INST("%_arr_float_uint_4 = OpTypeArray %float %uint_4");
}
TEST_F(SpvGeneratorImplTest, Type_Array_ExplicitStride) {
- auto* arr = ty.array(ty.f32(), 4u, 16);
- auto id = generator_.Type(arr);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 32\n"
- "%4 = OpTypeInt 32 0\n"
- "%3 = OpConstant %4 4\n"
- "%1 = OpTypeArray %2 %3\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 16\n");
+ generator_.Type(ty.array<f32, 4>(16));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpDecorate %_arr_float_uint_4 ArrayStride 16");
+ EXPECT_INST("%_arr_float_uint_4 = OpTypeArray %float %uint_4");
}
TEST_F(SpvGeneratorImplTest, Type_Array_NestedArray) {
- auto* arr = ty.array(ty.array(ty.f32(), 64u), 4u);
- auto id = generator_.Type(arr);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%3 = OpTypeFloat 32\n"
- "%5 = OpTypeInt 32 0\n"
- "%4 = OpConstant %5 64\n"
- "%2 = OpTypeArray %3 %4\n"
- "%6 = OpConstant %5 4\n"
- "%1 = OpTypeArray %2 %6\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()),
- "OpDecorate %2 ArrayStride 4\n"
- "OpDecorate %1 ArrayStride 256\n");
+ generator_.Type(ty.array(ty.array<f32, 64u>(), 4u));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpDecorate %_arr_float_uint_64 ArrayStride 4");
+ EXPECT_INST("OpDecorate %_arr__arr_float_uint_64_uint_4 ArrayStride 256");
+ EXPECT_INST("%_arr_float_uint_64 = OpTypeArray %float %uint_64");
+ EXPECT_INST("%_arr__arr_float_uint_64_uint_4 = OpTypeArray %_arr_float_uint_64 %uint_4");
}
TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_DefaultStride) {
- auto* arr = ty.runtime_array(ty.f32());
- auto id = generator_.Type(arr);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 32\n"
- "%1 = OpTypeRuntimeArray %2\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 4\n");
+ generator_.Type(ty.array<f32>());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpDecorate %_runtimearr_float ArrayStride 4");
+ EXPECT_INST("%_runtimearr_float = OpTypeRuntimeArray %float");
}
TEST_F(SpvGeneratorImplTest, Type_RuntimeArray_ExplicitStride) {
- auto* arr = ty.runtime_array(ty.f32(), 16);
- auto id = generator_.Type(arr);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(),
- "%2 = OpTypeFloat 32\n"
- "%1 = OpTypeRuntimeArray %2\n");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), "OpDecorate %1 ArrayStride 16\n");
+ generator_.Type(ty.array<f32>(16));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpDecorate %_runtimearr_float ArrayStride 16");
+ EXPECT_INST("%_runtimearr_float = OpTypeRuntimeArray %float");
}
TEST_F(SpvGeneratorImplTest, Type_Struct) {
@@ -191,20 +172,15 @@
{mod.symbols.Register("a"), ty.f32()},
{mod.symbols.Register("b"), ty.vec4<i32>()},
});
- auto id = generator_.Type(str);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32
-%4 = OpTypeInt 32 1
-%3 = OpTypeVector %4 4
-%1 = OpTypeStruct %2 %3
-)");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), R"(OpMemberDecorate %1 0 Offset 0
-OpMemberDecorate %1 1 Offset 16
-)");
- EXPECT_EQ(DumpInstructions(generator_.Module().Debug()), R"(OpMemberName %1 0 "a"
-OpMemberName %1 1 "b"
-OpName %1 "MyStruct"
-)");
+ generator_.Type(str);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpMemberName %MyStruct 0 \"a\"");
+ EXPECT_INST("OpMemberName %MyStruct 1 \"b\"");
+ EXPECT_INST("OpName %MyStruct \"MyStruct\"");
+ EXPECT_INST("OpMemberDecorate %MyStruct 0 Offset 0");
+ EXPECT_INST("OpMemberDecorate %MyStruct 1 Offset 16");
+ EXPECT_INST("%MyStruct = OpTypeStruct %float %v4int");
}
TEST_F(SpvGeneratorImplTest, Type_Struct_MatrixLayout) {
@@ -215,56 +191,206 @@
// Matrices nested inside arrays need layout decorations on the struct member too.
{mod.symbols.Register("arr"), ty.array(ty.array(ty.mat2x4<f16>(), 4), 4)},
});
- auto id = generator_.Type(str);
- EXPECT_EQ(id, 1u);
- EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32
-%3 = OpTypeVector %4 3
-%2 = OpTypeMatrix %3 3
-%9 = OpTypeFloat 16
-%8 = OpTypeVector %9 4
-%7 = OpTypeMatrix %8 2
-%11 = OpTypeInt 32 0
-%10 = OpConstant %11 4
-%6 = OpTypeArray %7 %10
-%5 = OpTypeArray %6 %10
-%1 = OpTypeStruct %2 %5
-)");
- EXPECT_EQ(DumpInstructions(generator_.Module().Annots()), R"(OpMemberDecorate %1 0 Offset 0
-OpMemberDecorate %1 0 ColMajor
-OpMemberDecorate %1 0 MatrixStride 16
-OpDecorate %6 ArrayStride 16
-OpDecorate %5 ArrayStride 64
-OpMemberDecorate %1 1 Offset 48
-OpMemberDecorate %1 1 ColMajor
-OpMemberDecorate %1 1 MatrixStride 8
-)");
- EXPECT_EQ(DumpInstructions(generator_.Module().Debug()), R"(OpMemberName %1 0 "m"
-OpMemberName %1 1 "arr"
-OpName %1 "MyStruct"
-)");
+ generator_.Type(str);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpMemberDecorate %MyStruct 0 ColMajor");
+ EXPECT_INST("OpMemberDecorate %MyStruct 0 MatrixStride 16");
+ EXPECT_INST("OpMemberDecorate %MyStruct 1 ColMajor");
+ EXPECT_INST("OpMemberDecorate %MyStruct 1 MatrixStride 8");
+ EXPECT_INST("%MyStruct = OpTypeStruct %mat3v3float %_arr__arr_mat2v4half_uint_4_uint_4");
}
+TEST_F(SpvGeneratorImplTest, Type_Sampler) {
+ generator_.Type(ty.sampler());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpTypeSampler");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_SamplerComparison) {
+ generator_.Type(ty.comparison_sampler());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpTypeSampler");
+}
+
+TEST_F(SpvGeneratorImplTest, Type_Samplers_Dedup) {
+ auto id = generator_.Type(ty.sampler());
+ EXPECT_EQ(generator_.Type(ty.comparison_sampler()), id);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+}
+
+using Dim = type::TextureDimension;
+struct TextureCase {
+ std::string result;
+ Dim dim;
+ TestElementType format = kF32;
+};
+
+using Type_SampledTexture = SpvGeneratorImplTestWithParam<TextureCase>;
+TEST_P(Type_SampledTexture, Emit) {
+ auto params = GetParam();
+ generator_.Type(ty.Get<type::SampledTexture>(params.dim, MakeScalarType(params.format)));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.result);
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Type_SampledTexture,
+ testing::Values(
+ TextureCase{"%1 = OpTypeImage %float 1D 0 0 0 1 Unknown", Dim::k1d, kF32},
+ TextureCase{"%1 = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d, kF32},
+ TextureCase{"%1 = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray, kF32},
+ TextureCase{"%1 = OpTypeImage %float 3D 0 0 0 1 Unknown", Dim::k3d, kF32},
+ TextureCase{"%1 = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube, kF32},
+ TextureCase{"%1 = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray, kF32},
+ TextureCase{"%1 = OpTypeImage %int 1D 0 0 0 1 Unknown", Dim::k1d, kI32},
+ TextureCase{"%1 = OpTypeImage %int 2D 0 0 0 1 Unknown", Dim::k2d, kI32},
+ TextureCase{"%1 = OpTypeImage %int 2D 0 1 0 1 Unknown", Dim::k2dArray, kI32},
+ TextureCase{"%1 = OpTypeImage %int 3D 0 0 0 1 Unknown", Dim::k3d, kI32},
+ TextureCase{"%1 = OpTypeImage %int Cube 0 0 0 1 Unknown", Dim::kCube, kI32},
+ TextureCase{"%1 = OpTypeImage %int Cube 0 1 0 1 Unknown", Dim::kCubeArray, kI32},
+ TextureCase{"%1 = OpTypeImage %uint 1D 0 0 0 1 Unknown", Dim::k1d, kU32},
+ TextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 1 Unknown", Dim::k2d, kU32},
+ TextureCase{"%1 = OpTypeImage %uint 2D 0 1 0 1 Unknown", Dim::k2dArray, kU32},
+ TextureCase{"%1 = OpTypeImage %uint 3D 0 0 0 1 Unknown", Dim::k3d, kU32},
+ TextureCase{"%1 = OpTypeImage %uint Cube 0 0 0 1 Unknown", Dim::kCube, kU32},
+ TextureCase{"%1 = OpTypeImage %uint Cube 0 1 0 1 Unknown", Dim::kCubeArray, kU32}));
+
+using Type_MultisampledTexture = SpvGeneratorImplTestWithParam<TextureCase>;
+TEST_P(Type_MultisampledTexture, Emit) {
+ auto params = GetParam();
+ generator_.Type(ty.Get<type::MultisampledTexture>(params.dim, MakeScalarType(params.format)));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.result);
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Type_MultisampledTexture,
+ testing::Values(TextureCase{"%1 = OpTypeImage %float 2D 0 0 1 1 Unknown", Dim::k2d, kF32},
+ TextureCase{"%1 = OpTypeImage %int 2D 0 0 1 1 Unknown", Dim::k2d, kI32},
+ TextureCase{"%1 = OpTypeImage %uint 2D 0 0 1 1 Unknown", Dim::k2d, kU32}));
+
+using Type_DepthTexture = SpvGeneratorImplTestWithParam<TextureCase>;
+TEST_P(Type_DepthTexture, Emit) {
+ auto params = GetParam();
+ generator_.Type(ty.Get<type::DepthTexture>(params.dim));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.result);
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest,
+ Type_DepthTexture,
+ testing::Values(TextureCase{"%1 = OpTypeImage %float 2D 0 0 0 1 Unknown", Dim::k2d},
+ TextureCase{"%1 = OpTypeImage %float 2D 0 1 0 1 Unknown", Dim::k2dArray},
+ TextureCase{"%1 = OpTypeImage %float Cube 0 0 0 1 Unknown", Dim::kCube},
+ TextureCase{"%1 = OpTypeImage %float Cube 0 1 0 1 Unknown", Dim::kCubeArray}));
+
+TEST_F(SpvGeneratorImplTest, Type_DepthMultiSampledTexture) {
+ generator_.Type(ty.Get<type::DepthMultisampledTexture>(Dim::k2d));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%1 = OpTypeImage %float 2D 0 0 1 1 Unknown");
+}
+
+using Format = builtin::TexelFormat;
+struct StorageTextureCase {
+ std::string result;
+ Dim dim;
+ Format format;
+};
+using Type_StorageTexture = SpvGeneratorImplTestWithParam<StorageTextureCase>;
+TEST_P(Type_StorageTexture, Emit) {
+ auto params = GetParam();
+ generator_.Type(
+ ty.Get<type::StorageTexture>(params.dim, params.format, builtin::Access::kWrite,
+ type::StorageTexture::SubtypeFor(params.format, mod.Types())));
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(params.result);
+ if (params.format == builtin::TexelFormat::kRg32Uint ||
+ params.format == builtin::TexelFormat::kRg32Sint ||
+ params.format == builtin::TexelFormat::kRg32Float) {
+ EXPECT_INST("OpCapability StorageImageExtendedFormats");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
+ Type_StorageTexture,
+ testing::Values(
+ // Test all the dimensions with a single format.
+ StorageTextureCase{"%1 = OpTypeImage %float 1D 0 0 0 2 R32f", //
+ Dim::k1d, Format::kR32Float},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 R32f", //
+ Dim::k2d, Format::kR32Float},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 1 0 2 R32f", //
+ Dim::k2dArray, Format::kR32Float},
+ StorageTextureCase{"%1 = OpTypeImage %float 3D 0 0 0 2 R32f", //
+ Dim::k3d, Format::kR32Float},
+
+ // Test all the formats with 2D.
+ // TODO(jrprice): Enable this format when we polyfill it.
+ // StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Bgra8Unorm",
+ // Dim::k2d, Format::kBgra8Unorm},
+ StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 R32i", //
+ Dim::k2d, Format::kR32Sint},
+ StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 R32u", //
+ Dim::k2d, Format::kR32Uint},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rg32f", //
+ Dim::k2d, Format::kRg32Float},
+ StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rg32i", //
+ Dim::k2d, Format::kRg32Sint},
+ StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rg32ui", //
+ Dim::k2d, Format::kRg32Uint},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba16f", //
+ Dim::k2d, Format::kRgba16Float},
+ StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba16i", //
+ Dim::k2d, Format::kRgba16Sint},
+ StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba16ui", //
+ Dim::k2d, Format::kRgba16Uint},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba32f", //
+ Dim::k2d, Format::kRgba32Float},
+ StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba32i", //
+ Dim::k2d, Format::kRgba32Sint},
+ StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba32ui", //
+ Dim::k2d, Format::kRgba32Uint},
+ StorageTextureCase{"%1 = OpTypeImage %int 2D 0 0 0 2 Rgba8i", //
+ Dim::k2d, Format::kRgba8Sint},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba8Snorm", //
+ Dim::k2d, Format::kRgba8Snorm},
+ StorageTextureCase{"%1 = OpTypeImage %uint 2D 0 0 0 2 Rgba8ui", //
+ Dim::k2d, Format::kRgba8Uint},
+ StorageTextureCase{"%1 = OpTypeImage %float 2D 0 0 0 2 Rgba8", //
+ Dim::k2d, Format::kRgba8Unorm}));
+
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpvGeneratorImplTest, Type_Multiple) {
- EXPECT_EQ(generator_.Type(ty.i32()), 1u);
- EXPECT_EQ(generator_.Type(ty.u32()), 2u);
- EXPECT_EQ(generator_.Type(ty.f32()), 3u);
- EXPECT_EQ(generator_.Type(ty.f16()), 4u);
- EXPECT_EQ(DumpTypes(), R"(%1 = OpTypeInt 32 1
-%2 = OpTypeInt 32 0
-%3 = OpTypeFloat 32
-%4 = OpTypeFloat 16
+ generator_.Type(ty.i32());
+ generator_.Type(ty.u32());
+ generator_.Type(ty.f32());
+ generator_.Type(ty.f16());
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %int = OpTypeInt 32 1
+ %uint = OpTypeInt 32 0
+ %float = OpTypeFloat 32
+ %half = OpTypeFloat 16
)");
}
// Test that we do not emit the same type more than once.
TEST_F(SpvGeneratorImplTest, Type_Deduplicate) {
- auto* i32 = ty.i32();
- EXPECT_EQ(generator_.Type(i32), 1u);
- EXPECT_EQ(generator_.Type(i32), 1u);
- EXPECT_EQ(generator_.Type(i32), 1u);
- EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 1\n");
+ auto id = generator_.Type(ty.i32());
+ EXPECT_EQ(generator_.Type(ty.i32()), id);
+ EXPECT_EQ(generator_.Type(ty.i32()), id);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
}
} // namespace
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_unary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_unary_test.cc
new file mode 100644
index 0000000..4d9320c
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_unary_test.cc
@@ -0,0 +1,77 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+#include "src/tint/ir/unary.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+/// A parameterized test case.
+struct UnaryTestCase {
+ /// The element type to test.
+ TestElementType type;
+ /// The unary operation.
+ enum ir::Unary::Kind kind;
+ /// The expected SPIR-V instruction.
+ std::string spirv_inst;
+ /// The expected SPIR-V result type name.
+ std::string spirv_type_name;
+};
+
+using Arithmetic = SpvGeneratorImplTestWithParam<UnaryTestCase>;
+TEST_P(Arithmetic, Scalar) {
+ auto params = GetParam();
+
+ auto* arg = b.FunctionParam("arg", MakeScalarType(params.type));
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Unary(params.kind, MakeScalarType(params.type), arg);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %" + params.spirv_type_name + " %arg");
+}
+TEST_P(Arithmetic, Vector) {
+ auto params = GetParam();
+
+ auto* arg = b.FunctionParam("arg", MakeVectorType(params.type));
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({arg});
+ b.With(func->Block(), [&] {
+ auto* result = b.Unary(params.kind, MakeVectorType(params.type), arg);
+ b.Return(func);
+ mod.SetName(result, "result");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%result = " + params.spirv_inst + " %v2" + params.spirv_type_name + " %arg");
+}
+INSTANTIATE_TEST_SUITE_P(
+ SpvGeneratorImplTest_Unary,
+ Arithmetic,
+ testing::Values(UnaryTestCase{kI32, ir::Unary::Kind::kComplement, "OpNot", "int"},
+ UnaryTestCase{kU32, ir::Unary::Kind::kComplement, "OpNot", "uint"},
+ UnaryTestCase{kI32, ir::Unary::Kind::kNegation, "OpSNegate", "int"},
+ UnaryTestCase{kF32, ir::Unary::Kind::kNegation, "OpFNegate", "float"},
+ UnaryTestCase{kF16, ir::Unary::Kind::kNegation, "OpFNegate", "half"}));
+
+} // namespace
+} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index a79e326..9bd5d06 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "src/tint/type/pointer.h"
+#include "src/tint/type/sampled_texture.h"
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
namespace tint::writer::spirv {
@@ -23,599 +24,310 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- b.Var(ty.ptr<function, i32>());
+ b.Var("v", ty.ptr<function, i32>());
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Function_int Function");
}
TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
v->SetInitializer(b.Constant(42_i));
-
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%8 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-OpStore %5 %8
-OpReturn
-OpFunctionEnd
-)");
-}
-
-TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
- auto* func = b.Function("foo", ty.void_());
-
- b.With(func->Block(), [&] {
- b.Var("myvar", ty.ptr<function, i32>());
- b.Return(func);
- });
-
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-OpName %5 "myvar"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Function_int Function");
+ EXPECT_INST("OpStore %v %int_42");
}
TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
auto* i = b.If(true);
b.With(i->True(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
v->SetInitializer(b.Constant(42_i));
b.ExitIf(i);
});
- b.With(i->False(), [&] { b.Return(func); });
-
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%9 = OpTypeBool
-%8 = OpConstantTrue %9
-%12 = OpTypeInt 32 1
-%11 = OpTypePointer Function %12
-%13 = OpConstant %12 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%10 = OpVariable %11 Function
-OpSelectionMerge %5 None
-OpBranchConditional %8 %6 %7
-%6 = OpLabel
-OpStore %10 %13
-OpBranch %5
-%7 = OpLabel
-OpReturn
-%5 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %foo = OpFunction %void None %3
+ %4 = OpLabel
+ %v = OpVariable %_ptr_Function_int Function
+ OpSelectionMerge %5 None
+ OpBranchConditional %true %6 %5
+ %6 = OpLabel
+ OpStore %v %int_42
+ OpBranch %5
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Load) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* store_ty = ty.i32();
- auto* v = b.Var(ty.ptr(function, store_ty));
- b.Load(v);
+ auto* v = b.Var("v", ty.ptr<function, i32>());
+ auto* result = b.Load(v);
b.Return(func);
+ mod.SetName(result, "result");
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-%8 = OpLoad %7 %5
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Function_int Function");
+ EXPECT_INST("%result = OpLoad %int %v");
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
auto* func = b.Function("foo", ty.void_());
-
b.With(func->Block(), [&] {
- auto* v = b.Var(ty.ptr<function, i32>());
+ auto* v = b.Var("v", ty.ptr<function, i32>());
b.Store(v, 42_i);
b.Return(func);
});
- ASSERT_TRUE(IRIsValid()) << Error();
-
- generator_.EmitFunction(func);
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
-%2 = OpTypeVoid
-%3 = OpTypeFunction %2
-%7 = OpTypeInt 32 1
-%6 = OpTypePointer Function %7
-%8 = OpConstant %7 42
-%1 = OpFunction %2 None %3
-%4 = OpLabel
-%5 = OpVariable %6 Function
-OpStore %5 %8
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Function_int Function");
+ EXPECT_INST("OpStore %v %int_42");
}
TEST_F(SpvGeneratorImplTest, PrivateVar_NoInit) {
- b.RootBlock()->Append(b.Var(ty.ptr<private_, i32>()));
+ b.RootBlock()->Append(b.Var("v", ty.ptr<private_, i32>()));
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %4 "unused_entry_point"
-OpExecutionMode %4 LocalSize 1 1 1
-OpName %4 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Private %3
-%1 = OpVariable %2 Private
-%5 = OpTypeVoid
-%6 = OpTypeFunction %5
-%4 = OpFunction %5 None %6
-%7 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Private_int Private");
}
TEST_F(SpvGeneratorImplTest, PrivateVar_WithInit) {
- auto* v = b.Var(ty.ptr<private_, i32>());
+ auto* v = b.Var("v", ty.ptr<private_, i32>());
v->SetInitializer(b.Constant(42_i));
b.RootBlock()->Append(v);
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpName %5 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Private %3
-%4 = OpConstant %3 42
-%1 = OpVariable %2 Private %4
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
-}
-
-TEST_F(SpvGeneratorImplTest, PrivateVar_Name) {
- auto* v = b.Var("myvar", ty.ptr<private_, i32>());
- v->SetInitializer(b.Constant(42_i));
- b.RootBlock()->Append(v);
-
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpName %1 "myvar"
-OpName %5 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Private %3
-%4 = OpConstant %3 42
-%1 = OpVariable %2 Private %4
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Private_int Private %int_42");
}
TEST_F(SpvGeneratorImplTest, PrivateVar_LoadAndStore) {
- auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
-
- auto* store_ty = ty.i32();
- auto* v = b.Var(ty.ptr(private_, store_ty));
+ auto* v = b.Var("v", ty.ptr<private_, i32>());
v->SetInitializer(b.Constant(42_i));
b.RootBlock()->Append(v);
+ auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
b.With(func->Block(), [&] {
- b.Load(v);
- auto* add = b.Add(store_ty, v, 1_i);
+ auto* load = b.Load(v);
+ auto* add = b.Add(ty.i32(), load, 1_i);
b.Store(v, add);
b.Return(func);
+ mod.SetName(load, "load");
+ mod.SetName(add, "add");
});
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %5 "foo"
-OpExecutionMode %5 OriginUpperLeft
-OpName %5 "foo"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Private %3
-%4 = OpConstant %3 42
-%1 = OpVariable %2 Private %4
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%11 = OpConstant %3 1
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-%9 = OpLoad %3 %1
-%10 = OpIAdd %3 %1 %11
-OpStore %1 %10
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Private_int Private %int_42");
+ EXPECT_INST("%load = OpLoad %int %v");
+ EXPECT_INST("OpStore %v %add");
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
- b.RootBlock()->Append(b.Var(ty.ptr<workgroup, i32>()));
+ b.RootBlock()->Append(b.Var("v", ty.ptr<workgroup, i32>()));
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %4 "unused_entry_point"
-OpExecutionMode %4 LocalSize 1 1 1
-OpName %4 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Workgroup %3
-%1 = OpVariable %2 Workgroup
-%5 = OpTypeVoid
-%6 = OpTypeFunction %5
-%4 = OpFunction %5 None %6
-%7 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
-}
-
-TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
- b.RootBlock()->Append(b.Var("myvar", ty.ptr<workgroup, i32>()));
-
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %4 "unused_entry_point"
-OpExecutionMode %4 LocalSize 1 1 1
-OpName %1 "myvar"
-OpName %4 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Workgroup %3
-%1 = OpVariable %2 Workgroup
-%5 = OpTypeVoid
-%6 = OpTypeFunction %5
-%4 = OpFunction %5 None %6
-%7 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Workgroup_int Workgroup");
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
+ auto* v = b.RootBlock()->Append(b.Var("v", ty.ptr<workgroup, i32>()));
+
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
-
- auto* store_ty = ty.i32();
- auto* v = b.RootBlock()->Append(b.Var(ty.ptr(workgroup, store_ty)));
-
b.With(func->Block(), [&] {
- b.Load(v);
- auto* add = b.Add(store_ty, v, 1_i);
+ auto* load = b.Load(v);
+ auto* add = b.Add(ty.i32(), load, 1_i);
b.Store(v, add);
b.Return(func);
+ mod.SetName(load, "load");
+ mod.SetName(add, "add");
});
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %4 "foo"
-OpExecutionMode %4 LocalSize 1 1 1
-OpName %4 "foo"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Workgroup %3
-%1 = OpVariable %2 Workgroup
-%5 = OpTypeVoid
-%6 = OpTypeFunction %5
-%10 = OpConstant %3 1
-%4 = OpFunction %5 None %6
-%7 = OpLabel
-%8 = OpLoad %3 %1
-%9 = OpIAdd %3 %1 %10
-OpStore %1 %9
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%v = OpVariable %_ptr_Workgroup_int Workgroup");
+ EXPECT_INST("%load = OpLoad %int %v");
+ EXPECT_INST("OpStore %v %add");
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
- b.RootBlock()->Append(b.Var(ty.ptr<workgroup, i32>()));
+ b.RootBlock()->Append(b.Var("v", ty.ptr<workgroup, i32>()));
// Create a generator with the zero_init_workgroup_memory flag set to `true`.
spirv::GeneratorImplIr gen(&mod, true);
- ASSERT_TRUE(gen.Generate()) << gen.Diagnostics().str();
- EXPECT_EQ(DumpModule(gen.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpName %5 "unused_entry_point"
-%3 = OpTypeInt 32 1
-%2 = OpTypePointer Workgroup %3
-%4 = OpConstantNull %3
-%1 = OpVariable %2 Workgroup %4
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
-)");
+ ASSERT_TRUE(Generate(gen)) << Error() << output_;
+ EXPECT_INST("%4 = OpConstantNull %int");
+ EXPECT_INST("%v = OpVariable %_ptr_Workgroup_int Workgroup %4");
}
TEST_F(SpvGeneratorImplTest, StorageVar) {
- auto* v = b.Var(ty.ptr<storage, i32>());
+ auto* v = b.Var("v", ty.ptr<storage, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "unused_entry_point"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer StorageBuffer %3
-%1 = OpVariable %2 StorageBuffer
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpDecorate %tint_symbol_1 Block
+ OpDecorate %1 DescriptorSet 0
+ OpDecorate %1 Binding 0
)");
-}
-
-TEST_F(SpvGeneratorImplTest, StorageVar_Name) {
- auto* v = b.Var("myvar", ty.ptr<storage, i32>());
- v->SetBindingPoint(0, 0);
- b.RootBlock()->Append(v);
-
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "unused_entry_point"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer StorageBuffer %3
-%1 = OpVariable %2 StorageBuffer
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ EXPECT_INST(R"(
+%tint_symbol_1 = OpTypeStruct %int
+%_ptr_StorageBuffer_tint_symbol_1 = OpTypePointer StorageBuffer %tint_symbol_1
+ %1 = OpVariable %_ptr_StorageBuffer_tint_symbol_1 StorageBuffer
)");
}
TEST_F(SpvGeneratorImplTest, StorageVar_LoadAndStore) {
- auto* v = b.Var(ty.ptr<storage, i32>());
+ auto* v = b.Var("v", ty.ptr<storage, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
-
b.With(func->Block(), [&] {
- b.Load(v);
- auto* add = b.Add(ty.i32(), v, 1_i);
+ auto* load = b.Load(v);
+ auto* add = b.Add(ty.i32(), load, 1_i);
b.Store(v, add);
b.Return(func);
+ mod.SetName(load, "load");
+ mod.SetName(add, "add");
});
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "foo"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "foo"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer StorageBuffer %3
-%1 = OpVariable %2 StorageBuffer
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%10 = OpTypePointer StorageBuffer %4
-%12 = OpTypeInt 32 0
-%11 = OpConstant %12 0
-%16 = OpConstant %4 1
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-%9 = OpAccessChain %10 %1 %11
-%13 = OpLoad %4 %9
-%14 = OpAccessChain %10 %1 %11
-%15 = OpIAdd %4 %14 %16
-%17 = OpAccessChain %10 %1 %11
-OpStore %17 %15
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %9 = OpAccessChain %_ptr_StorageBuffer_int %1 %uint_0
+ %load = OpLoad %int %9
+ %add = OpIAdd %int %load %int_1
+ %16 = OpAccessChain %_ptr_StorageBuffer_int %1 %uint_0
+ OpStore %16 %add
)");
}
TEST_F(SpvGeneratorImplTest, UniformVar) {
- auto* v = b.Var(ty.ptr<uniform, i32>());
+ auto* v = b.Var("v", ty.ptr<uniform, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "unused_entry_point"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer Uniform %3
-%1 = OpVariable %2 Uniform
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpDecorate %tint_symbol_1 Block
+ OpDecorate %1 DescriptorSet 0
+ OpDecorate %1 Binding 0
)");
-}
-
-TEST_F(SpvGeneratorImplTest, UniformVar_Name) {
- auto* v = b.Var("myvar", ty.ptr<uniform, i32>());
- v->SetBindingPoint(0, 0);
- b.RootBlock()->Append(v);
-
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "unused_entry_point"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "unused_entry_point"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer Uniform %3
-%1 = OpVariable %2 Uniform
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-OpReturn
-OpFunctionEnd
+ EXPECT_INST(R"(
+%tint_symbol_1 = OpTypeStruct %int
+%_ptr_Uniform_tint_symbol_1 = OpTypePointer Uniform %tint_symbol_1
+ %1 = OpVariable %_ptr_Uniform_tint_symbol_1 Uniform
)");
}
TEST_F(SpvGeneratorImplTest, UniformVar_Load) {
- auto* v = b.Var(ty.ptr<uniform, i32>());
+ auto* v = b.Var("v", ty.ptr<uniform, i32>());
v->SetBindingPoint(0, 0);
b.RootBlock()->Append(v);
auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
std::array{1u, 1u, 1u});
-
b.With(func->Block(), [&] {
- b.Load(v);
+ auto* load = b.Load(v);
b.Return(func);
+ mod.SetName(load, "load");
});
- ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
- EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
-OpMemoryModel Logical GLSL450
-OpEntryPoint GLCompute %5 "foo"
-OpExecutionMode %5 LocalSize 1 1 1
-OpMemberName %3 0 "tint_symbol"
-OpName %3 "tint_symbol_1"
-OpName %5 "foo"
-OpMemberDecorate %3 0 Offset 0
-OpDecorate %3 Block
-OpDecorate %1 DescriptorSet 0
-OpDecorate %1 Binding 0
-%4 = OpTypeInt 32 1
-%3 = OpTypeStruct %4
-%2 = OpTypePointer Uniform %3
-%1 = OpVariable %2 Uniform
-%6 = OpTypeVoid
-%7 = OpTypeFunction %6
-%10 = OpTypePointer Uniform %4
-%12 = OpTypeInt 32 0
-%11 = OpConstant %12 0
-%5 = OpFunction %6 None %7
-%8 = OpLabel
-%9 = OpAccessChain %10 %1 %11
-%13 = OpLoad %4 %9
-OpReturn
-OpFunctionEnd
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ %9 = OpAccessChain %_ptr_Uniform_int %1 %uint_0
+ %load = OpLoad %int %9
)");
}
+TEST_F(SpvGeneratorImplTest, SamplerVar) {
+ auto* v =
+ b.Var("v", ty.ptr(builtin::AddressSpace::kHandle, ty.sampler(), builtin::Access::kRead));
+ v->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(v);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpDecorate %v DescriptorSet 0
+ OpDecorate %v Binding 0
+)");
+ EXPECT_INST(R"(
+ %3 = OpTypeSampler
+%_ptr_UniformConstant_3 = OpTypePointer UniformConstant %3
+ %v = OpVariable %_ptr_UniformConstant_3 UniformConstant
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, SamplerVar_Load) {
+ auto* v =
+ b.Var("v", ty.ptr(builtin::AddressSpace::kHandle, ty.sampler(), builtin::Access::kRead));
+ v->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(v);
+
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* load = b.Load(v);
+ b.Return(func);
+ mod.SetName(load, "load");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%load = OpLoad %3 %v");
+}
+
+TEST_F(SpvGeneratorImplTest, TextureVar) {
+ auto* v = b.Var("v", ty.ptr(builtin::AddressSpace::kHandle,
+ ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()),
+ builtin::Access::kRead));
+ v->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(v);
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST(R"(
+ OpDecorate %v DescriptorSet 0
+ OpDecorate %v Binding 0
+)");
+ EXPECT_INST(R"(
+ %3 = OpTypeImage %float 2D 0 0 0 1 Unknown
+%_ptr_UniformConstant_3 = OpTypePointer UniformConstant %3
+ %v = OpVariable %_ptr_UniformConstant_3 UniformConstant
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, TextureVar_Load) {
+ auto* v = b.Var("v", ty.ptr(builtin::AddressSpace::kHandle,
+ ty.Get<type::SampledTexture>(type::TextureDimension::k2d, ty.f32()),
+ builtin::Access::kRead));
+ v->SetBindingPoint(0, 0);
+ b.RootBlock()->Append(v);
+
+ auto* func = b.Function("foo", ty.void_());
+ b.With(func->Block(), [&] {
+ auto* load = b.Load(v);
+ b.Return(func);
+ mod.SetName(load, "load");
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("%load = OpLoad %3 %v");
+}
+
} // namespace
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h
index 27559be..dfc05b8 100644
--- a/src/tint/writer/spirv/ir/test_helper_ir.h
+++ b/src/tint/writer/spirv/ir/test_helper_ir.h
@@ -17,7 +17,9 @@
#include <string>
#include <utility>
+#include <vector>
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "spirv-tools/libspirv.hpp"
#include "src/tint/ir/builder.h"
@@ -39,6 +41,26 @@
kF32,
kF16,
};
+inline utils::StringStream& operator<<(utils::StringStream& out, TestElementType type) {
+ switch (type) {
+ case kBool:
+ out << "bool";
+ break;
+ case kI32:
+ out << "i32";
+ break;
+ case kU32:
+ out << "u32";
+ break;
+ case kF32:
+ out << "f32";
+ break;
+ case kF16:
+ out << "f16";
+ break;
+ }
+ return out;
+}
/// Base helper class for testing the SPIR-V generator implementation.
template <typename BASE>
@@ -57,7 +79,7 @@
/// The SPIR-V generator.
GeneratorImplIr generator_;
- /// Validation errors
+ /// Errors produced during codegen or SPIR-V validation.
std::string err_;
/// SPIR-V output.
@@ -66,38 +88,34 @@
/// @returns the error string from the validation
std::string Error() const { return err_; }
- /// @returns true if the IR module is valid
- bool IRIsValid() {
- auto res = ir::Validate(mod);
- if (!res) {
- err_ = res.Failure().str();
+ /// Run the specified generator on the IR module and validate the result.
+ /// @param generator the generator to use for SPIR-V generation
+ /// @returns true if generation and validation succeeded
+ bool Generate(GeneratorImplIr& generator) {
+ if (!generator.Generate()) {
+ err_ = generator.Diagnostics().str();
return false;
}
+
+ output_ = Disassemble(generator.Result(), SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
+ SPV_BINARY_TO_TEXT_OPTION_INDENT |
+ SPV_BINARY_TO_TEXT_OPTION_COMMENT);
+
+ if (!Validate(generator.Result())) {
+ return false;
+ }
+
return true;
}
/// Run the generator on the IR module and validate the result.
/// @returns true if generation and validation succeeded
- bool Generate() {
- if (!generator_.Generate()) {
- err_ = generator_.Diagnostics().str();
- return false;
- }
- if (!Validate()) {
- return false;
- }
-
- output_ = Disassemble(generator_.Result(), SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
- SPV_BINARY_TO_TEXT_OPTION_INDENT |
- SPV_BINARY_TO_TEXT_OPTION_COMMENT);
- return true;
- }
+ bool Generate() { return Generate(generator_); }
/// Validate the generated SPIR-V using the SPIR-V Tools Validator.
+ /// @param binary the SPIR-V binary module to validate
/// @returns true if validation succeeded, false otherwise
- bool Validate() {
- auto binary = generator_.Result();
-
+ bool Validate(const std::vector<uint32_t>& binary) {
std::string spv_errors;
auto msg_consumer = [&spv_errors](spv_message_level_t level, const char*,
const spv_position_t& position, const char* message) {
diff --git a/src/tint/writer/syntax_tree/generator_impl.cc b/src/tint/writer/syntax_tree/generator_impl.cc
index b6ba3ea..6ef2395 100644
--- a/src/tint/writer/syntax_tree/generator_impl.cc
+++ b/src/tint/writer/syntax_tree/generator_impl.cc
@@ -521,12 +521,12 @@
Line() << "]";
},
[&](const ast::IndexAttribute* index) {
- line() << "IndexAttribute [";
+ Line() << "IndexAttribute [";
{
ScopedIndent idx(this);
EmitExpression(index->expr);
}
- line() << "]";
+ Line() << "]";
},
[&](const ast::BuiltinAttribute* builtin) {
Line() << "BuiltinAttribute [";