[tint][core] Splat: Infer count from type
This is a redundant parameter, as it can be inferred from the type argument.
Reduces risk of getting this wrong (which a few tests did)
Bug: 342096120
Change-Id: I201e76314ee5cfc5bb22692ba20eeaaa83b9e948
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/189123
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 105ed38..ee896b7 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -359,7 +359,7 @@
if (auto* build = std::get_if<ActionBuildSplat>(&next)) {
TINT_ASSERT(value_stack.Length() >= 1);
auto* el = value_stack.Pop();
- value_stack.Push(ctx.mgr.Splat(build->type, el, build->count));
+ value_stack.Push(ctx.mgr.Splat(build->type, el));
continue;
}
@@ -1331,7 +1331,7 @@
VectorRef<const Value*> args,
const Source&) {
if (auto* arg = args[0]) {
- return mgr.Splat(ty, arg, static_cast<const core::type::Vector*>(ty)->Width());
+ return mgr.Splat(ty, arg);
}
return nullptr;
}
diff --git a/src/tint/lang/core/constant/manager.cc b/src/tint/lang/core/constant/manager.cc
index 69feff2..ecd4a38 100644
--- a/src/tint/lang/core/constant/manager.cc
+++ b/src/tint/lang/core/constant/manager.cc
@@ -65,7 +65,7 @@
bool all_equal = true;
auto* first = elements.Front();
for (auto* el : elements) {
- if (!el) {
+ if (TINT_UNLIKELY(!el)) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
@@ -79,16 +79,15 @@
}
}
if (all_equal) {
- return Splat(type, elements.Front(), elements.Length());
+ return Splat(type, elements.Front());
}
return Get<constant::Composite>(type, std::move(elements), all_zero, any_zero);
}
const constant::Splat* Manager::Splat(const core::type::Type* type,
- const constant::Value* element,
- size_t n) {
- return Get<constant::Splat>(type, element, n);
+ const constant::Value* element) {
+ return Get<constant::Splat>(type, element);
}
const Scalar<i32>* Manager::Get(i32 value) {
@@ -124,16 +123,16 @@
type, //
[&](const core::type::Vector* v) -> const Value* {
auto* zero_el = Zero(v->type());
- return Splat(type, zero_el, v->Width());
+ return Splat(type, zero_el);
},
[&](const core::type::Matrix* m) -> const Value* {
auto* zero_el = Zero(m->ColumnType());
- return Splat(type, zero_el, m->columns());
+ return Splat(type, zero_el);
},
[&](const core::type::Array* a) -> const Value* {
- if (auto n = a->ConstantCount()) {
+ if (a->ConstantCount()) {
if (auto* zero_el = Zero(a->ElemType())) {
- return Splat(type, zero_el, n.value());
+ return Splat(type, zero_el);
}
}
return nullptr;
@@ -152,7 +151,7 @@
}
if (zero_by_type.Count() == 1) {
// All members were of the same type, so the zero value is the same for all members.
- return Splat(type, zeros[0], s->Members().Length());
+ return Splat(type, zeros[0]);
}
return Composite(s, std::move(zeros));
},
diff --git a/src/tint/lang/core/constant/manager.h b/src/tint/lang/core/constant/manager.h
index 2eb5088..43497db 100644
--- a/src/tint/lang/core/constant/manager.h
+++ b/src/tint/lang/core/constant/manager.h
@@ -108,11 +108,8 @@
/// Constructs a splat constant.
/// @param type the splat type
/// @param element the splat element
- /// @param n the number of elements
/// @returns the value pointer
- const constant::Splat* Splat(const core::type::Type* type,
- const constant::Value* element,
- size_t n);
+ const constant::Splat* Splat(const core::type::Type* type, const constant::Value* element);
/// @param value the constant value
/// @return a Scalar holding the i32 value @p value
diff --git a/src/tint/lang/core/constant/splat.cc b/src/tint/lang/core/constant/splat.cc
index 07d8fbb..d34be7b 100644
--- a/src/tint/lang/core/constant/splat.cc
+++ b/src/tint/lang/core/constant/splat.cc
@@ -33,15 +33,15 @@
namespace tint::core::constant {
-Splat::Splat(const core::type::Type* t, const constant::Value* e, size_t n)
- : type(t), el(e), count(n) {}
+Splat::Splat(const core::type::Type* t, const constant::Value* e)
+ : type(t), el(e), count(t->Elements().count) {}
Splat::~Splat() = default;
const Splat* Splat::Clone(CloneContext& ctx) const {
auto* ty = type->Clone(ctx.type_ctx);
auto* element = el->Clone(ctx);
- return ctx.dst.Splat(ty, element, count);
+ return ctx.dst.Splat(ty, element);
}
} // namespace tint::core::constant
diff --git a/src/tint/lang/core/constant/splat.h b/src/tint/lang/core/constant/splat.h
index bf7a0c4..def9bfe 100644
--- a/src/tint/lang/core/constant/splat.h
+++ b/src/tint/lang/core/constant/splat.h
@@ -44,8 +44,7 @@
/// Constructor
/// @param t the splat type
/// @param e the splat element
- /// @param n the number of items in the splat
- Splat(const core::type::Type* t, const Value* e, size_t n);
+ Splat(const core::type::Type* t, const Value* e);
~Splat() override;
/// @returns the type of the splat
diff --git a/src/tint/lang/core/constant/splat_test.cc b/src/tint/lang/core/constant/splat_test.cc
index d45f614..553cf14 100644
--- a/src/tint/lang/core/constant/splat_test.cc
+++ b/src/tint/lang/core/constant/splat_test.cc
@@ -46,9 +46,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1);
EXPECT_TRUE(SpfPos0->AllZero());
EXPECT_TRUE(SpfNeg0->AllZero());
@@ -62,9 +62,9 @@
auto* fNeg0 = constants.Get(-0_f);
auto* fPos1 = constants.Get(1_f);
- auto* SpfPos0 = constants.Splat(vec3f, fPos0, 3);
- auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 3);
- auto* SpfPos1 = constants.Splat(vec3f, fPos1, 3);
+ auto* SpfPos0 = constants.Splat(vec3f, fPos0);
+ auto* SpfNeg0 = constants.Splat(vec3f, fNeg0);
+ auto* SpfPos1 = constants.Splat(vec3f, fPos1);
EXPECT_TRUE(SpfPos0->AnyZero());
EXPECT_TRUE(SpfNeg0->AnyZero());
@@ -75,20 +75,21 @@
auto* vec3f = create<core::type::Vector>(create<core::type::F32>(), 3u);
auto* f1 = constants.Get(1_f);
- auto* sp = constants.Splat(vec3f, f1, 2);
+ auto* sp = constants.Splat(vec3f, f1);
ASSERT_NE(sp->Index(0), nullptr);
ASSERT_NE(sp->Index(1), nullptr);
- ASSERT_EQ(sp->Index(2), nullptr);
+ ASSERT_NE(sp->Index(2), nullptr);
EXPECT_EQ(sp->Index(0)->As<Scalar<f32>>()->ValueOf(), 1.f);
EXPECT_EQ(sp->Index(1)->As<Scalar<f32>>()->ValueOf(), 1.f);
+ EXPECT_EQ(sp->Index(2)->As<Scalar<f32>>()->ValueOf(), 1.f);
}
TEST_F(ConstantTest_Splat, Clone) {
- auto* vec3i = create<core::type::Vector>(create<core::type::I32>(), 3u);
+ auto* vec2i = create<core::type::Vector>(create<core::type::I32>(), 2u);
auto* val = constants.Get(12_i);
- auto* sp = constants.Splat(vec3i, val, 2);
+ auto* sp = constants.Splat(vec2i, val);
constant::Manager mgr;
constant::CloneContext ctx{core::type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr};
diff --git a/src/tint/lang/core/constant/value_test.cc b/src/tint/lang/core/constant/value_test.cc
index ad0c297..6bb31fd 100644
--- a/src/tint/lang/core/constant/value_test.cc
+++ b/src/tint/lang/core/constant/value_test.cc
@@ -54,8 +54,8 @@
TEST_F(ConstantTest_Value, Equal_Splat_Splat) {
auto* vec3f = create<core::type::Vector>(create<core::type::F32>(), 3u);
- auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
- auto* vec3f_2_2_2 = constants.Splat(vec3f, constants.Get(2_f), 3);
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f));
+ auto* vec3f_2_2_2 = constants.Splat(vec3f, constants.Get(2_f));
EXPECT_TRUE(vec3f_1_1_1->Equal(vec3f_1_1_1));
EXPECT_FALSE(vec3f_2_2_2->Equal(vec3f_1_1_1));
@@ -78,7 +78,7 @@
TEST_F(ConstantTest_Value, Equal_Splat_Composite) {
auto* vec3f = create<core::type::Vector>(create<core::type::F32>(), 3u);
- auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f), 3);
+ auto* vec3f_1_1_1 = constants.Splat(vec3f, constants.Get(1_f));
auto* vec3f_1_2_1 = constants.Composite(
vec3f, Vector{constants.Get(1_f), constants.Get(2_f), constants.Get(1_f)});
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index c422433..780afd2 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -836,7 +836,7 @@
const core::constant::Value* CreateConstantSplat(const pb::ConstantValueSplat& splat_in) {
auto* type = Type(splat_in.type());
auto* elem = ConstantValue(splat_in.elements());
- return mod_out_.constant_values.Splat(type, elem, splat_in.count());
+ return mod_out_.constant_values.Splat(type, elem);
}
const core::constant::Value* ConstantValue(uint32_t id) {
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 76f92b8..033e316 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -396,7 +396,7 @@
TEST_F(IRBinaryRoundtripTest, Return_vec3f_Splat) {
auto* fn = b.Function("Function", ty.vec3<f32>());
- b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<vec3<f32>>(1_f, 3)); });
+ b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<vec3<f32>>(1_f)); });
RUN_TEST();
}
@@ -409,7 +409,7 @@
TEST_F(IRBinaryRoundtripTest, Return_mat2x3f_Splat) {
auto* fn = b.Function("Function", ty.mat2x3<f32>());
- b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<mat2x3<f32>>(1_f, 6)); });
+ b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<mat2x3<f32>>(b.Splat<vec3<f32>>(1_f))); });
RUN_TEST();
}
@@ -421,7 +421,7 @@
TEST_F(IRBinaryRoundtripTest, Return_array_f32_Splat) {
auto* fn = b.Function("Function", ty.array<f32, 3>());
- b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<array<f32, 3>>(1_i, 3)); });
+ b.Append(fn->Block(), [&] { b.Return(fn, b.Splat<array<f32, 3>>(1_i)); });
RUN_TEST();
}
diff --git a/src/tint/lang/core/ir/builder.h b/src/tint/lang/core/ir/builder.h
index 240afc1..e3b797b 100644
--- a/src/tint/lang/core/ir/builder.h
+++ b/src/tint/lang/core/ir/builder.h
@@ -397,23 +397,20 @@
/// Creates a new ir::Constant
/// @param ty the splat type
/// @param value the splat value
- /// @param size the number of items
/// @returns the new constant
template <typename ARG>
- ir::Constant* Splat(const core::type::Type* ty, ARG&& value, size_t size) {
- return Constant(
- ir.constant_values.Splat(ty, ConstantValue(std::forward<ARG>(value)), size));
+ ir::Constant* Splat(const core::type::Type* ty, ARG&& value) {
+ return Constant(ir.constant_values.Splat(ty, ConstantValue(std::forward<ARG>(value))));
}
/// Creates a new ir::Constant
/// @tparam TYPE the splat type
/// @param value the splat value
- /// @param size the number of items
/// @returns the new constant
template <typename TYPE, typename ARG>
- ir::Constant* Splat(ARG&& value, size_t size) {
+ ir::Constant* Splat(ARG&& value) {
auto* type = ir.Types().Get<TYPE>();
- return Splat(type, std::forward<ARG>(value), size);
+ return Splat(type, std::forward<ARG>(value));
}
/// Creates a new ir::Constant
@@ -918,7 +915,7 @@
template <typename VAL>
ir::CoreBinary* Not(const core::type::Type* type, VAL&& val) {
if (auto* vec = type->As<core::type::Vector>()) {
- return Equal(type, std::forward<VAL>(val), Splat(vec, false, vec->Width()));
+ return Equal(type, std::forward<VAL>(val), Splat(vec, false));
} else {
return Equal(type, std::forward<VAL>(val), Constant(false));
}
diff --git a/src/tint/lang/core/ir/transform/binary_polyfill.cc b/src/tint/lang/core/ir/transform/binary_polyfill.cc
index a5aec3c..e74a828 100644
--- a/src/tint/lang/core/ir/transform/binary_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/binary_polyfill.cc
@@ -126,7 +126,7 @@
/// @returns a value with the same number of vector components as @p match
ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
if (auto* vec = match->As<core::type::Vector>()) {
- return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
+ return b.Splat(MatchWidth(element->Type(), match), element);
}
return element;
}
diff --git a/src/tint/lang/core/ir/transform/builtin_polyfill.cc b/src/tint/lang/core/ir/transform/builtin_polyfill.cc
index ef68d5b..f62c7c6 100644
--- a/src/tint/lang/core/ir/transform/builtin_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/builtin_polyfill.cc
@@ -225,7 +225,7 @@
/// @returns a value with the same number of vector components as @p match
ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
if (auto* vec = match->As<core::type::Vector>()) {
- return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
+ return b.Splat(MatchWidth(element->Type(), match), element);
}
return element;
}
@@ -575,13 +575,12 @@
auto* sampler = call->Args()[1];
auto* coords = call->Args()[2];
b.InsertBefore(call, [&] {
- auto* vec2f = ty.vec2<f32>();
- auto* dims = b.Call(ty.vec2<u32>(), core::BuiltinFn::kTextureDimensions, texture);
- auto* fdims = b.Convert(vec2f, dims);
- auto* half_texel = b.Divide(vec2f, b.Splat(vec2f, 0.5_f, 2), fdims);
- auto* one_minus_half_texel = b.Subtract(vec2f, b.Splat(vec2f, 1_f, 2), half_texel);
- auto* clamped =
- b.Call(vec2f, core::BuiltinFn::kClamp, coords, half_texel, one_minus_half_texel);
+ auto* dims = b.Call<vec2<u32>>(core::BuiltinFn::kTextureDimensions, texture);
+ auto* fdims = b.Convert<vec2<f32>>(dims);
+ auto* half_texel = b.Divide<vec2<f32>>(b.Splat<vec2<f32>>(0.5_f), fdims);
+ auto* one_minus_half_texel = b.Subtract<vec2<f32>>(b.Splat<vec2<f32>>(1_f), half_texel);
+ auto* clamped = b.Call<vec2<f32>>(core::BuiltinFn::kClamp, coords, half_texel,
+ one_minus_half_texel);
b.CallWithResult(call->DetachResult(), core::BuiltinFn::kTextureSampleLevel, texture,
sampler, clamped, 0_f);
});
diff --git a/src/tint/lang/core/ir/transform/conversion_polyfill.cc b/src/tint/lang/core/ir/transform/conversion_polyfill.cc
index 76d1920..36ea764 100644
--- a/src/tint/lang/core/ir/transform/conversion_polyfill.cc
+++ b/src/tint/lang/core/ir/transform/conversion_polyfill.cc
@@ -207,7 +207,7 @@
/// @returns a value with the same number of vector components as @p match
ir::Constant* MatchWidth(ir::Constant* element, const core::type::Type* match) {
if (auto* vec = match->As<core::type::Vector>()) {
- return b.Splat(MatchWidth(element->Type(), match), element, vec->Width());
+ return b.Splat(MatchWidth(element->Type(), match), element);
}
return element;
}
diff --git a/src/tint/lang/core/ir/transform/demote_to_helper_test.cc b/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
index e26ff8d..db15893 100644
--- a/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
+++ b/src/tint/lang/core/ir/transform/demote_to_helper_test.cc
@@ -550,7 +550,7 @@
b.ExitIf(ifelse);
});
b.Call(ty.void_(), core::BuiltinFn::kTextureStore, b.Load(texture), coord,
- b.Splat(b.ir.Types().vec4<f32>(), 0.5_f, 4));
+ b.Splat(b.ir.Types().vec4<f32>(), 0.5_f));
b.Return(ep, 0.5_f);
});
diff --git a/src/tint/lang/core/ir/transform/direct_variable_access_test.cc b/src/tint/lang/core/ir/transform/direct_variable_access_test.cc
index cf464d3..fd89e61 100644
--- a/src/tint/lang/core/ir/transform/direct_variable_access_test.cc
+++ b/src/tint/lang/core/ir/transform/direct_variable_access_test.cc
@@ -1393,7 +1393,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.array<i32, 4>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<array<i32, 4>>(0_i));
b.Return(fn_a);
});
@@ -1475,7 +1475,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.vec4<i32>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<vec4<i32>>(0_i));
b.Return(fn_a);
});
@@ -2068,7 +2068,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.vec4<i32>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<vec4<i32>>(0_i));
b.Return(fn_a);
});
@@ -2870,7 +2870,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.array<i32, 4>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<array<i32, 4>>(0_i));
b.Return(fn_a);
});
@@ -2956,7 +2956,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.array<i32, 4>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<array<i32, 4>>(0_i));
b.Return(fn_a);
});
@@ -4138,7 +4138,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.array<i32, 4>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<array<i32, 4>>(0_i));
b.Return(fn_a);
});
@@ -4373,7 +4373,7 @@
b.FunctionParam("post", ty.i32()),
});
b.Append(fn_a->Block(), [&] {
- b.Store(fn_a_p, b.Splat(ty.array<i32, 4>(), 0_i, 4));
+ b.Store(fn_a_p, b.Splat<array<i32, 4>>(0_i));
b.Return(fn_a);
});
diff --git a/src/tint/lang/core/ir/transform/robustness.cc b/src/tint/lang/core/ir/transform/robustness.cc
index 9284b9a..56f2839 100644
--- a/src/tint/lang/core/ir/transform/robustness.cc
+++ b/src/tint/lang/core/ir/transform/robustness.cc
@@ -292,7 +292,7 @@
auto* one = b.Constant(1_u);
if (auto* vec = args[idx]->Type()->As<type::Vector>()) {
type = ty.vec(type, vec->Width());
- one = b.Splat(type, one, vec->Width());
+ one = b.Splat(type, one);
}
auto* dims = clamped_level ? b.Call(type, core::BuiltinFn::kTextureDimensions, args[0],
clamped_level)
diff --git a/src/tint/lang/msl/writer/printer/constant_test.cc b/src/tint/lang/msl/writer/printer/constant_test.cc
index e4910d5..1ce484b 100644
--- a/src/tint/lang/msl/writer/printer/constant_test.cc
+++ b/src/tint/lang/msl/writer/printer/constant_test.cc
@@ -134,7 +134,7 @@
}
TEST_F(MslPrinterTest, Constant_Vector_Splat) {
- auto* c = b.Splat(ty.vec3<f32>(), 1.5_f, 3);
+ auto* c = b.Splat<vec3<f32>>(1.5_f);
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
b.Let("a", c);
@@ -198,7 +198,7 @@
}
TEST_F(MslPrinterTest, Constant_Matrix_Splat) {
- auto* c = b.Splat(ty.mat3x2<f32>(), 1.5_f, 3);
+ auto* c = b.Splat<mat3x2<f32>>(1.5_f);
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
b.Let("a", c);
@@ -270,7 +270,7 @@
}
TEST_F(MslPrinterTest, Constant_Array_Splat) {
- auto* c = b.Splat(ty.array<f32, 3>(), 1.5_f, 3);
+ auto* c = b.Splat<array<f32, 3>>(1.5_f);
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
b.Let("a", c);
@@ -338,7 +338,7 @@
{mod.symbols.Register("a"), ty.f32()},
{mod.symbols.Register("b"), ty.f32()},
});
- auto* c = b.Splat(s, 1.5_f, 2);
+ auto* c = b.Splat(s, 1.5_f);
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
b.Let("a", c);
diff --git a/src/tint/lang/msl/writer/printer/var_test.cc b/src/tint/lang/msl/writer/printer/var_test.cc
index b3e141a..1bfdbdc 100644
--- a/src/tint/lang/msl/writer/printer/var_test.cc
+++ b/src/tint/lang/msl/writer/printer/var_test.cc
@@ -183,7 +183,7 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* v = b.Var("a", ty.ptr<core::AddressSpace::kFunction, vec3<f32>>());
- v->SetInitializer(b.Splat(ty.vec3<f32>(), 0_f, 3));
+ v->SetInitializer(b.Splat<vec3<f32>>(0_f));
b.Return(func);
});
@@ -200,7 +200,7 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* v = b.Var("a", ty.ptr<core::AddressSpace::kFunction, vec3<f16>>());
- v->SetInitializer(b.Splat(ty.vec3<f16>(), 0_h, 3));
+ v->SetInitializer(b.Splat<vec3<f16>>(0_h));
b.Return(func);
});
@@ -216,8 +216,8 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* v = b.Var("a", ty.ptr<core::AddressSpace::kFunction, mat2x3<f32>>());
- v->SetInitializer(b.Composite(ty.mat2x3<f32>(), b.Splat(ty.vec3<f32>(), 0_f, 3),
- b.Splat(ty.vec3<f32>(), 0_f, 3)));
+ v->SetInitializer(
+ b.Composite(ty.mat2x3<f32>(), b.Splat<vec3<f32>>(0_f), b.Splat<vec3<f32>>(0_f)));
b.Return(func);
});
@@ -234,8 +234,8 @@
auto* func = b.Function("foo", ty.void_());
b.Append(func->Block(), [&] {
auto* v = b.Var("a", ty.ptr<core::AddressSpace::kFunction, mat2x3<f16>>());
- v->SetInitializer(b.Composite(ty.mat2x3<f16>(), b.Splat(ty.vec3<f16>(), 0_h, 3),
- b.Splat(ty.vec3<f16>(), 0_h, 3)));
+ v->SetInitializer(
+ b.Composite(ty.mat2x3<f16>(), b.Splat<vec3<f16>>(0_h), b.Splat<vec3<f16>>(0_h)));
b.Return(func);
});
diff --git a/src/tint/lang/spirv/reader/lower/shader_io_test.cc b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
index d004762..79beb46 100644
--- a/src/tint/lang/spirv/reader/lower/shader_io_test.cc
+++ b/src/tint/lang/spirv/reader/lower/shader_io_test.cc
@@ -1082,7 +1082,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
b.Return(ep);
});
@@ -1137,7 +1137,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
b.Return(ep);
});
@@ -1191,7 +1191,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(ep->Block(), [&] { //
- b.Store(color, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(color, b.Splat<vec4<f32>>(1_f));
b.Return(ep);
});
@@ -1247,7 +1247,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(ep->Block(), [&] { //
- b.Store(color, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(color, b.Splat<vec4<f32>>(1_f));
b.Return(ep);
});
@@ -1318,9 +1318,9 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
- b.Store(color1, b.Splat<vec4<f32>>(0.5_f, 4));
- b.Store(color2, b.Splat<vec4<f32>>(0.25_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
+ b.Store(color1, b.Splat<vec4<f32>>(0.5_f));
+ b.Store(color2, b.Splat<vec4<f32>>(0.25_f));
b.Return(ep);
});
@@ -1413,9 +1413,9 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
auto* ptr = ty.ptr(core::AddressSpace::kOut, ty.vec4<f32>());
- b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f, 4));
- b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f, 4));
- b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f, 4));
+ b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f));
+ b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f));
+ b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f));
b.Return(ep);
});
@@ -1545,9 +1545,9 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
auto* ptr = ty.ptr(core::AddressSpace::kOut, ty.vec4<f32>());
- b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f, 4));
- b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f, 4));
- b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f, 4));
+ b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f));
+ b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f));
+ b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f));
b.Return(ep);
});
@@ -1670,9 +1670,9 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep->Block(), [&] { //
auto* ptr = ty.ptr(core::AddressSpace::kOut, ty.vec4<f32>());
- b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f, 4));
- b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f, 4));
- b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f, 4));
+ b.Store(b.Access(ptr, builtins, 0_u), b.Splat<vec4<f32>>(1_f));
+ b.Store(b.Access(ptr, colors, 0_u), b.Splat<vec4<f32>>(0.5_f));
+ b.Store(b.Access(ptr, colors, 1_u), b.Splat<vec4<f32>>(0.25_f));
b.Return(ep);
});
@@ -1785,21 +1785,21 @@
auto* ep1 = b.Function("main1", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep1->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
b.Return(ep1);
});
auto* ep2 = b.Function("main2", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep2->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
- b.Store(color1, b.Splat<vec4<f32>>(0.5_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
+ b.Store(color1, b.Splat<vec4<f32>>(0.5_f));
b.Return(ep2);
});
auto* ep3 = b.Function("main3", ty.void_(), core::ir::Function::PipelineStage::kVertex);
b.Append(ep3->Block(), [&] { //
- b.Store(position, b.Splat<vec4<f32>>(1_f, 4));
- b.Store(color2, b.Splat<vec4<f32>>(0.25_f, 4));
+ b.Store(position, b.Splat<vec4<f32>>(1_f));
+ b.Store(color2, b.Splat<vec4<f32>>(0.25_f));
b.Return(ep3);
});
@@ -1913,7 +1913,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(ep->Block(), [&] { //
- b.Store(color, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(color, b.Splat<vec4<f32>>(1_f));
auto* load = b.Load(color);
auto* mul = b.Multiply<vec4<f32>>(load, 2_f);
b.Store(color, mul);
@@ -1976,7 +1976,7 @@
auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(ep->Block(), [&] { //
- b.Store(color, b.Splat<vec4<f32>>(1_f, 4));
+ b.Store(color, b.Splat<vec4<f32>>(1_f));
auto* load = b.LoadVectorElement(color, 2_u);
auto* mul = b.Multiply<f32>(load, 2_f);
b.StoreVectorElement(color, 2_u, mul);
@@ -2042,7 +2042,7 @@
auto* access_1 = b.Access(ty.ptr(core::AddressSpace::kOut, ty.vec4<f32>()), color);
auto* access_2 = b.Access(ty.ptr(core::AddressSpace::kOut, ty.vec4<f32>()), access_1);
auto* load = b.LoadVectorElement(access_2, 2_u);
- auto* mul = b.Multiply<vec4<f32>>(b.Splat<vec4<f32>>(1_f, 4), load);
+ auto* mul = b.Multiply<vec4<f32>>(b.Splat<vec4<f32>>(1_f), load);
b.Store(access_2, mul);
b.Return(ep);
});
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 73066a6..e1a00e0 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1791,8 +1791,8 @@
if (auto* vec = res_ty->As<core::type::Vector>()) {
// Splat the scalars into vectors.
- one = b_.Splat(vec, one, vec->Width());
- zero = b_.Splat(vec, zero, vec->Width());
+ one = b_.Splat(vec, one);
+ zero = b_.Splat(vec, zero);
}
op = spv::Op::OpSelect;
diff --git a/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
index 1049aac..a215e73 100644
--- a/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
@@ -1466,7 +1466,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSample, t, s, coords,
- b.Splat(ty.vec2<i32>(), 1_i, 2));
+ b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1506,7 +1506,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSample, t, s, coords,
- array_idx, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1588,7 +1588,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleBias, t, s, coords,
- bias, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ bias, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1629,7 +1629,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleBias, t, s, coords,
- array_idx, bias, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, bias, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1710,7 +1710,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.f32(), core::BuiltinFn::kTextureSampleCompare, t, s, coords, dref,
- b.Splat(ty.vec2<i32>(), 1_i, 2));
+ b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1751,7 +1751,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.f32(), core::BuiltinFn::kTextureSampleCompare, t, s, coords,
- array_idx, bias, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, bias, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1833,7 +1833,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.f32(), core::BuiltinFn::kTextureSampleCompareLevel, t, s, coords,
- dref, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ dref, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1874,7 +1874,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.f32(), core::BuiltinFn::kTextureSampleCompareLevel, t, s, coords,
- array_idx, bias, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, bias, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -1958,7 +1958,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleGrad, t, s, coords,
- ddx, ddy, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ ddx, ddy, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2000,7 +2000,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleGrad, t, s, coords,
- array_idx, ddx, ddy, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, ddx, ddy, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2082,7 +2082,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleLevel, t, s, coords,
- lod, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ lod, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2123,7 +2123,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureSampleLevel, t, s, coords,
- array_idx, lod, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ array_idx, lod, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2205,7 +2205,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureGather, component, t, s,
- coords, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ coords, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2246,7 +2246,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureGather, component, t, s,
- coords, array_idx, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ coords, array_idx, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
@@ -2366,7 +2366,7 @@
b.Append(func->Block(), [&] {
auto* result = b.Call(ty.vec4<f32>(), core::BuiltinFn::kTextureGatherCompare, t, s, coords,
- depth, b.Splat(ty.vec2<i32>(), 1_i, 2));
+ depth, b.Splat<vec2<i32>>(1_i));
b.Return(func, result);
});
diff --git a/src/tint/lang/spirv/writer/texture_builtin_test.cc b/src/tint/lang/spirv/writer/texture_builtin_test.cc
index 773dd2f..6c35730 100644
--- a/src/tint/lang/spirv/writer/texture_builtin_test.cc
+++ b/src/tint/lang/spirv/writer/texture_builtin_test.cc
@@ -190,7 +190,7 @@
for (const auto& arg : params.args) {
auto* value = MakeScalarValue(arg.type, arg_value++);
if (arg.width > 1) {
- value = b.Splat(ty.vec(value->Type(), arg.width), value, arg.width);
+ value = b.Splat(ty.vec(value->Type(), arg.width), value);
}
args.Push(value);
mod.SetName(value, arg.name);
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
index 7b9b3cd..766469f 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program_test.cc
@@ -151,7 +151,7 @@
auto* fn = b.Function("f", ty.vec4<f32>(), core::ir::Function::PipelineStage::kVertex);
fn->SetReturnBuiltin(core::BuiltinValue::kPosition);
- fn->Block()->Append(b.Return(fn, b.Splat(ty.vec4<f32>(), 0_f, 4)));
+ fn->Block()->Append(b.Return(fn, b.Splat<vec4<f32>>(0_f)));
EXPECT_WGSL(R"(
@vertex
@@ -194,7 +194,7 @@
fn->SetReturnBuiltin(core::BuiltinValue::kPosition);
fn->SetReturnInvariant(true);
- fn->Block()->Append(b.Return(fn, b.Splat(ty.vec4<f32>(), 0_f, 4)));
+ fn->Block()->Append(b.Return(fn, b.Splat<vec4<f32>>(0_f)));
EXPECT_WGSL(R"(
@vertex
@@ -208,7 +208,7 @@
auto* fn = b.Function("f", ty.vec4<f32>(), core::ir::Function::PipelineStage::kFragment);
fn->SetReturnLocation(1, std::nullopt);
- fn->Block()->Append(b.Return(fn, b.Splat(ty.vec4<f32>(), 0_f, 4)));
+ fn->Block()->Append(b.Return(fn, b.Splat<vec4<f32>>(0_f)));
EXPECT_WGSL(R"(
@fragment