[ir][ms] Emit Lets
Add emission of `let` instructions to the MSL IR generator. Also add `ConstantValue()` and `Composite()` helpers to `ir::Builder`.
Bug: tint:1967
Change-Id: I5b6af981bce7ee4e497b101c95a1fc662f70cc3e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144042
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 2048955..a71257c 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -57,10 +57,12 @@
#include "src/tint/lang/core/ir/user_call.h"
#include "src/tint/lang/core/ir/value.h"
#include "src/tint/lang/core/ir/var.h"
+#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/bool.h"
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/u32.h"
#include "src/tint/lang/core/type/vector.h"
@@ -219,29 +221,72 @@
/// Creates a ir::Constant for an i32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(i32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a u32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(u32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a f32 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(f32 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a f16 Scalar
/// @param v the value
/// @returns the new constant
- ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); }
+ ir::Constant* Constant(f16 v) { return Constant(ConstantValue(v)); }
/// Creates a ir::Constant for a bool Scalar
/// @param v the value
/// @returns the new constant
template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
ir::Constant* Constant(BOOL v) {
- return Constant(ir.constant_values.Get(v));
+ return Constant(ConstantValue(v));
+ }
+
+ /// Returns the constant::Value held by an ir::Constant
+ /// @param constant the ir::Constant to unwrap
+ /// @returns the constant::Value held by @p constant
+ const constant::Value* ConstantValue(ir::Constant* constant) { return constant->Value(); }
+
+ /// Returns the constant::Value for an i32 scalar, from ir.constant_values
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(i32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a u32 scalar, from ir.constant_values
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(u32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for an f32 scalar, from ir.constant_values
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(f32 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for an f16 scalar, from ir.constant_values
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ const constant::Value* ConstantValue(f16 v) { return ir.constant_values.Get(v); }
+
+ /// Returns the constant::Value for a bool scalar, from ir.constant_values
+ /// @param v the value
+ /// @returns the constant::Value for @p v
+ template <typename BOOL, typename = std::enable_if_t<std::is_same_v<BOOL, bool>>>
+ const constant::Value* ConstantValue(BOOL v) {
+ return ir.constant_values.Get(v);
+ }
+
+ /// Creates a new ir::Constant holding a composite constant::Value
+ /// @param ty the type of the composite
+ /// @param values the element values; each may be a scalar or an ir::Constant*
+ /// @returns the new ir::Constant
+ template <typename... ARGS, typename = DisableIfVectorLike<ARGS...>>
+ ir::Constant* Composite(const type::Type* ty, ARGS&&... values) {
+ return Constant(
+ ir.constant_values.Composite(ty, Vector{ConstantValue(std::forward<ARGS>(values))...}));
}
/// @param in the input value. One of: nullptr, ir::Value*, ir::Instruction* or a numeric value.