[spirv-reader] Detect constant vector splats

When all operands to an OpConstantComposite for a vector type are the
same, use a vector splat.

Change-Id: I53680e1c8a1589ab0b2047fb6cfdaaae2ba5b654
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/137882
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/reader/spirv/function_composite_test.cc b/src/tint/reader/spirv/function_composite_test.cc
index 21cdbae..3d311f1 100644
--- a/src/tint/reader/spirv/function_composite_test.cc
+++ b/src/tint/reader/spirv/function_composite_test.cc
@@ -138,7 +138,7 @@
     EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = mat3x2f("
                                                                   "vec2f(50.0f, 60.0f), "
                                                                   "vec2f(60.0f, 50.0f), "
-                                                                  "vec2f(70.0f, 70.0f));"));
+                                                                  "vec2f(70.0f));"));
 }
 
 TEST_F(SpvParserTest_Composite_Construct, Array) {
diff --git a/src/tint/reader/spirv/function_memory_test.cc b/src/tint/reader/spirv/function_memory_test.cc
index b75a3bc..a35aa8c 100644
--- a/src/tint/reader/spirv/function_memory_test.cc
+++ b/src/tint/reader/spirv/function_memory_test.cc
@@ -499,8 +499,7 @@
     auto fe = p->function_emitter(100);
     EXPECT_TRUE(fe.EmitBody());
     auto ast_body = fe.ast_body();
-    EXPECT_THAT(test::ToString(p->program(), ast_body),
-                HasSubstr("myvar[2u] = vec4f(42.0f, 42.0f, 42.0f, 42.0f);"));
+    EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("myvar[2u] = vec4f(42.0f);"));
 }
 
 TEST_F(SpvParserMemoryTest, EmitStatement_AccessChain_Array) {
@@ -531,8 +530,7 @@
     auto fe = p->function_emitter(100);
     EXPECT_TRUE(fe.EmitBody());
     auto ast_body = fe.ast_body();
-    EXPECT_THAT(test::ToString(p->program(), ast_body),
-                HasSubstr("myvar[2u] = vec4f(42.0f, 42.0f, 42.0f, 42.0f);"));
+    EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("myvar[2u] = vec4f(42.0f);"));
 }
 
 TEST_F(SpvParserMemoryTest, EmitStatement_AccessChain_Struct) {
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index 971d53c..0c61368 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -1928,6 +1928,8 @@
             }
 
             // Generate a composite from explicit components.
+            bool all_same = true;
+            uint32_t first_id = 0u;
             ExpressionList ast_components;
             if (!inst->WhileEachInId([&](const uint32_t* id_ref) -> bool {
                     auto component = MakeConstantExpression(*id_ref);
@@ -1936,14 +1938,30 @@
                         return false;
                     }
                     ast_components.Push(component.expr);
+
+                    // Check if this argument is different from the others.
+                    if (first_id != 0u) {
+                        if (*id_ref != first_id) {
+                            all_same = false;
+                        }
+                    } else {
+                        first_id = *id_ref;
+                    }
+
                     return true;
                 })) {
                 // We've already emitted a diagnostic.
                 return {};
             }
 
-            auto* expr = builder_.Call(source, original_ast_type->Build(builder_),
-                                       std::move(ast_components));
+            const ast::Expression* expr = nullptr;
+            if (all_same && original_ast_type->Is<Vector>()) {
+                // We're constructing a vector and all the operands were the same, so use a splat.
+                expr = builder_.Call(source, original_ast_type->Build(builder_), ast_components[0]);
+            } else {
+                expr = builder_.Call(source, original_ast_type->Build(builder_),
+                                     std::move(ast_components));
+            }
 
             if (def_use_mgr_->NumUses(id) == 1) {
                 // The constant is only used once, so just inline its use.
diff --git a/src/tint/reader/spirv/parser_impl_handle_test.cc b/src/tint/reader/spirv/parser_impl_handle_test.cc
index 6b519aa..4f1e1a6 100644
--- a/src/tint/reader/spirv/parser_impl_handle_test.cc
+++ b/src/tint/reader/spirv/parser_impl_handle_test.cc
@@ -4171,11 +4171,11 @@
 
   continuing {
     x_27 = (x_26 + 1i);
-    x_24 = vec2f(1.0f, 1.0f);
+    x_24 = vec2f(1.0f);
     x_26 = x_27;
   }
 }
-textureStore(Output2Texture2D, vec2i(vec2u(1u, 1u)), vec4f(x_24, 0.0f, 0.0f));
+textureStore(Output2Texture2D, vec2i(vec2u(1u)), vec4f(x_24, 0.0f, 0.0f));
 return;
 )";
     ASSERT_EQ(expect, got);
diff --git a/test/tint/bug/tint/1520.spvasm.expected.wgsl b/test/tint/bug/tint/1520.spvasm.expected.wgsl
index 1cbb3f3..30d5150 100644
--- a/test/tint/bug/tint/1520.spvasm.expected.wgsl
+++ b/test/tint/bug/tint/1520.spvasm.expected.wgsl
@@ -19,9 +19,9 @@
 
 var<private> vcolor_S0 : vec4f;
 
-const x_46 = vec4i(1i, 1i, 1i, 1i);
+const x_46 = vec4i(1i);
 
-const x_57 = vec4i(2i, 2i, 2i, 2i);
+const x_57 = vec4i(2i);
 
 fn test_int_S1_c0_b() -> bool {
   var unknown : i32;
@@ -76,9 +76,9 @@
   return x_66;
 }
 
-const x_91 = vec4f(1.0f, 1.0f, 1.0f, 1.0f);
+const x_91 = vec4f(1.0f);
 
-const x_102 = vec4f(2.0f, 2.0f, 2.0f, 2.0f);
+const x_102 = vec4f(2.0f);
 
 fn main_1() {
   var outputColor_S0 : vec4f;
diff --git a/test/tint/bug/tint/1932.spvasm.expected.wgsl b/test/tint/bug/tint/1932.spvasm.expected.wgsl
index 5216558..172c292 100644
--- a/test/tint/bug/tint/1932.spvasm.expected.wgsl
+++ b/test/tint/bug/tint/1932.spvasm.expected.wgsl
@@ -1,4 +1,4 @@
-const x_22 = vec2f(2.0f, 2.0f);
+const x_22 = vec2f(2.0f);
 
 fn main_1() {
   let distance_1 = x_22;
diff --git a/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl b/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
index c5a058a..e80cd30 100644
--- a/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
+++ b/test/tint/ptr_ref/access/matrix.spvasm.expected.wgsl
@@ -1,7 +1,7 @@
 fn main_1() {
   var m = mat3x3f();
   m = mat3x3f(vec3f(1.0f, 2.0f, 3.0f), vec3f(4.0f, 5.0f, 6.0f), vec3f(7.0f, 8.0f, 9.0f));
-  m[1i] = vec3f(5.0f, 5.0f, 5.0f);
+  m[1i] = vec3f(5.0f);
   return;
 }