[spirv-reader] Detect constant vector splats
When all operands to an OpConstantComposite for a vector type are the
same, use a vector splat.
Change-Id: I53680e1c8a1589ab0b2047fb6cfdaaae2ba5b654
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/137882
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/reader/spirv/function_composite_test.cc b/src/tint/reader/spirv/function_composite_test.cc
index 21cdbae..3d311f1 100644
--- a/src/tint/reader/spirv/function_composite_test.cc
+++ b/src/tint/reader/spirv/function_composite_test.cc
@@ -138,7 +138,7 @@
EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = mat3x2f("
"vec2f(50.0f, 60.0f), "
"vec2f(60.0f, 50.0f), "
- "vec2f(70.0f, 70.0f));"));
+ "vec2f(70.0f));"));
}
TEST_F(SpvParserTest_Composite_Construct, Array) {
diff --git a/src/tint/reader/spirv/function_memory_test.cc b/src/tint/reader/spirv/function_memory_test.cc
index b75a3bc..a35aa8c 100644
--- a/src/tint/reader/spirv/function_memory_test.cc
+++ b/src/tint/reader/spirv/function_memory_test.cc
@@ -499,8 +499,7 @@
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody());
auto ast_body = fe.ast_body();
- EXPECT_THAT(test::ToString(p->program(), ast_body),
- HasSubstr("myvar[2u] = vec4f(42.0f, 42.0f, 42.0f, 42.0f);"));
+ EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("myvar[2u] = vec4f(42.0f);"));
}
TEST_F(SpvParserMemoryTest, EmitStatement_AccessChain_Array) {
@@ -531,8 +530,7 @@
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody());
auto ast_body = fe.ast_body();
- EXPECT_THAT(test::ToString(p->program(), ast_body),
- HasSubstr("myvar[2u] = vec4f(42.0f, 42.0f, 42.0f, 42.0f);"));
+ EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("myvar[2u] = vec4f(42.0f);"));
}
TEST_F(SpvParserMemoryTest, EmitStatement_AccessChain_Struct) {
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index 971d53c..0c61368 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -1928,6 +1928,8 @@
}
// Generate a composite from explicit components.
+ bool all_same = true;
+ uint32_t first_id = 0u;
ExpressionList ast_components;
if (!inst->WhileEachInId([&](const uint32_t* id_ref) -> bool {
auto component = MakeConstantExpression(*id_ref);
@@ -1936,14 +1938,30 @@
return false;
}
ast_components.Push(component.expr);
+
+ // Check if this argument is different from the others.
+ if (first_id != 0u) {
+ if (*id_ref != first_id) {
+ all_same = false;
+ }
+ } else {
+ first_id = *id_ref;
+ }
+
return true;
})) {
// We've already emitted a diagnostic.
return {};
}
- auto* expr = builder_.Call(source, original_ast_type->Build(builder_),
- std::move(ast_components));
+ const ast::Expression* expr = nullptr;
+ if (all_same && original_ast_type->Is<Vector>()) {
+ // We're constructing a vector and all the operands were the same, so use a splat.
+ expr = builder_.Call(source, original_ast_type->Build(builder_), ast_components[0]);
+ } else {
+ expr = builder_.Call(source, original_ast_type->Build(builder_),
+ std::move(ast_components));
+ }
if (def_use_mgr_->NumUses(id) == 1) {
// The constant is only used once, so just inline its use.
diff --git a/src/tint/reader/spirv/parser_impl_handle_test.cc b/src/tint/reader/spirv/parser_impl_handle_test.cc
index 6b519aa..4f1e1a6 100644
--- a/src/tint/reader/spirv/parser_impl_handle_test.cc
+++ b/src/tint/reader/spirv/parser_impl_handle_test.cc
@@ -4171,11 +4171,11 @@
continuing {
x_27 = (x_26 + 1i);
- x_24 = vec2f(1.0f, 1.0f);
+ x_24 = vec2f(1.0f);
x_26 = x_27;
}
}
-textureStore(Output2Texture2D, vec2i(vec2u(1u, 1u)), vec4f(x_24, 0.0f, 0.0f));
+textureStore(Output2Texture2D, vec2i(vec2u(1u)), vec4f(x_24, 0.0f, 0.0f));
return;
)";
ASSERT_EQ(expect, got);
diff --git a/test/tint/bug/tint/1520.spvasm.expected.wgsl b/test/tint/bug/tint/1520.spvasm.expected.wgsl
index 1cbb3f3..30d5150 100644
--- a/test/tint/bug/tint/1520.spvasm.expected.wgsl
+++ b/test/tint/bug/tint/1520.spvasm.expected.wgsl
@@ -19,9 +19,9 @@
var<private> vcolor_S0 : vec4f;
-const x_46 = vec4i(1i, 1i, 1i, 1i);
+const x_46 = vec4i(1i);
-const x_57 = vec4i(2i, 2i, 2i, 2i);
+const x_57 = vec4i(2i);
fn test_int_S1_c0_b() -> bool {
var unknown : i32;
@@ -76,9 +76,9 @@
return x_66;
}
-const x_91 = vec4f(1.0f, 1.0f, 1.0f, 1.0f);
+const x_91 = vec4f(1.0f);
-const x_102 = vec4f(2.0f, 2.0f, 2.0f, 2.0f);
+const x_102 = vec4f(2.0f);
fn main_1() {
var outputColor_S0 : vec4f;
diff --git a/test/tint/bug/tint/1932.spvasm.expected.wgsl b/test/tint/bug/tint/1932.spvasm.expected.wgsl
index 5216558..172c292 100644
--- a/test/tint/bug/tint/1932.spvasm.expected.wgsl
+++ b/test/tint/bug/tint/1932.spvasm.expected.wgsl
@@ -1,4 +1,4 @@
-const x_22 = vec2f(2.0f, 2.0f);
+const x_22 = vec2f(2.0f);
fn main_1() {
let distance_1 = x_22;
diff --git a/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl b/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
index c5a058a..e80cd30 100644
--- a/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
+++ b/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
@@ -1,7 +1,7 @@
fn main_1() {
var m = mat3x3f();
m = mat3x3f(vec3f(1.0f, 2.0f, 3.0f), vec3f(4.0f, 5.0f, 6.0f), vec3f(7.0f, 8.0f, 9.0f));
- m[1i] = vec3f(5.0f, 5.0f, 5.0f);
+ m[1i] = vec3f(5.0f);
return;
}