[ir] Cleanup composite creation in tests
This CL adds some helpers to make composites easier to use in tests.
Bug: tint:1718
Change-Id: I16a0e94978c43efa619b31b6815089c8fff6983f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133920
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 976d866..78bcb9b 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -46,6 +46,7 @@
#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
#include "src/tint/type/u32.h"
+#include "src/tint/type/vector.h"
#include "src/tint/type/void.h"
namespace tint::ir {
@@ -119,6 +120,41 @@
return ir.constants_arena.Create<T>(std::forward<ARGS>(args)...);
}
+ /// @param v the value
+ /// @returns the constant value
+ const constant::Value* Bool(bool v) {
+ // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
+ return Constant(create<constant::Scalar<bool>>(ir.types.Get<type::Bool>(), v))->Value();
+ }
+
+ /// @param v the value
+ /// @returns the constant value
+ const constant::Value* U32(uint32_t v) {
+ // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
+ return Constant(create<constant::Scalar<u32>>(ir.types.Get<type::U32>(), u32(v)))->Value();
+ }
+
+ /// @param v the value
+ /// @returns the constant value
+ const constant::Value* I32(int32_t v) {
+ // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
+ return Constant(create<constant::Scalar<i32>>(ir.types.Get<type::I32>(), i32(v)))->Value();
+ }
+
+ /// @param v the value
+ /// @returns the constant value
+ const constant::Value* F16(float v) {
+ // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
+ return Constant(create<constant::Scalar<f16>>(ir.types.Get<type::F16>(), f16(v)))->Value();
+ }
+
+ /// @param v the value
+ /// @returns the constant value
+ const constant::Value* F32(float v) {
+ // TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
+ return Constant(create<constant::Scalar<f32>>(ir.types.Get<type::F32>(), f32(v)))->Value();
+ }
+
/// Creates a new ir::Constant
/// @param val the constant value
/// @returns the new constant
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index 0650f1b..4eb48bb 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -18,6 +18,7 @@
#include <utility>
#include "src/tint/type/type.h"
+#include "src/tint/type/vector.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/unique_allocator.h"
@@ -84,6 +85,23 @@
return types_.Find<TYPE>(std::forward<ARGS>(args)...);
}
+ /// @param inner the inner type
+ /// @param size the vector size
+ /// @returns the vector type
+ type::Type* vec(type::Type* inner, uint32_t size) { return Get<type::Vector>(inner, size); }
+
+ /// @param inner the inner type
+ /// @returns the vector type
+ type::Type* vec2(type::Type* inner) { return vec(inner, 2); }
+
+ /// @param inner the inner type
+ /// @returns the vector type
+ type::Type* vec3(type::Type* inner) { return vec(inner, 3); }
+
+ /// @param inner the inner type
+ /// @returns the vector type
+ type::Type* vec4(type::Type* inner) { return vec(inner, 4); }
+
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 6c131c5..8792239 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -147,12 +147,10 @@
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) {
auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
- auto* lhs = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
- utils::Vector{b.Constant(42_i)->Value(), b.Constant(-1_i)->Value()}, false, false);
- auto* rhs = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
- utils::Vector{b.Constant(0_i)->Value(), b.Constant(-43_i)->Value()}, false, false);
+ auto* lhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
+ utils::Vector{b.I32(42), b.I32(-1)}, false, false);
+ auto* rhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
+ utils::Vector{b.I32(0), b.I32(-43)}, false, false);
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
b.Constant(lhs), b.Constant(rhs)),
@@ -180,16 +178,12 @@
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) {
auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
- auto* lhs = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
- utils::Vector{b.Constant(42_f)->Value(), b.Constant(-1_f)->Value(),
- b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value()},
- false, false);
- auto* rhs = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
- utils::Vector{b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value(),
- b.Constant(-42_f)->Value(), b.Constant(1_f)->Value()},
- false, false);
+ auto* lhs = b.create<constant::Composite>(
+ mod.types.vec4(mod.types.Get<type::F32>()),
+ utils::Vector{b.F32(42), b.F32(-1), b.F32(0), b.F32(1.25)}, false, false);
+ auto* rhs = b.create<constant::Composite>(
+ mod.types.vec4(mod.types.Get<type::F32>()),
+ utils::Vector{b.F32(0), b.F32(1.25), b.F32(-42), b.F32(1)}, false, false);
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
b.Constant(lhs), b.Constant(rhs)),
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
index 6ab48aa..75775ff 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc
@@ -63,11 +63,10 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
- auto* t = b.Constant(true);
- auto* f = b.Constant(false);
- auto* v = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(mod.types.Get<type::Bool>(), 4u),
- utils::Vector{t->Value(), f->Value(), f->Value(), t->Value()}, false, true);
+ auto* v = b.create<constant::Composite>(
+ mod.types.vec4(mod.types.Get<type::Bool>()),
+ utils::Vector{b.Bool(true), b.Bool(false), b.Bool(false), b.Bool(true)}, false, true);
+
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool
%2 = OpTypeVector %3 4
@@ -78,12 +77,8 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
- auto* i = mod.types.Get<type::I32>();
- auto* i_42 = b.Constant(i32(42));
- auto* i_n1 = b.Constant(i32(-1));
- auto* v = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(i, 2u), utils::Vector{i_42->Value(), i_n1->Value()}, false,
- false);
+ auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
+ utils::Vector{b.I32(42), b.I32(-1)}, false, false);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
%2 = OpTypeVector %3 2
@@ -94,13 +89,9 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
- auto* u = mod.types.Get<type::U32>();
- auto* u_42 = b.Constant(u32(42));
- auto* u_0 = b.Constant(u32(0));
- auto* u_4b = b.Constant(u32(4000000000));
- auto* v = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(u, 3u),
- utils::Vector{u_42->Value(), u_0->Value(), u_4b->Value()}, false, true);
+ auto* v = b.create<constant::Composite>(mod.types.vec3(mod.types.Get<type::U32>()),
+ utils::Vector{b.U32(42), b.U32(0), b.U32(4000000000)},
+ false, true);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0
%2 = OpTypeVector %3 3
@@ -112,14 +103,9 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
- auto* f = mod.types.Get<type::F32>();
- auto* f_42 = b.Constant(f32(42));
- auto* f_0 = b.Constant(f32(0));
- auto* f_q = b.Constant(f32(0.25));
- auto* f_n1 = b.Constant(f32(-1));
- auto* v = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(f, 4u),
- utils::Vector{f_42->Value(), f_0->Value(), f_q->Value(), f_n1->Value()}, false, true);
+ auto* v = b.create<constant::Composite>(
+ mod.types.vec4(mod.types.Get<type::F32>()),
+ utils::Vector{b.F32(42), b.F32(0), b.F32(0.25), b.F32(-1)}, false, true);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 4
@@ -132,12 +118,8 @@
}
TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
- auto* h = mod.types.Get<type::F16>();
- auto* h_42 = b.Constant(f16(42));
- auto* h_q = b.Constant(f16(0.25));
- auto* v = mod.constants_arena.Create<constant::Composite>(
- mod.types.Get<type::Vector>(h, 2u), utils::Vector{h_42->Value(), h_q->Value()}, false,
- false);
+ auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::F16>()),
+ utils::Vector{b.F16(42), b.F16(0.25)}, false, false);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16
%2 = OpTypeVector %3 2