tint: Fix constant::Splat conversion of struct types

Conversion can happen for structure materialization (modf, frexp).
If both structure members are the same type and value, then a constant::Splat will be constructed, which needs to handle conversion.

Bug: chromium:1417515
Change-Id: Iadd14ce00b8d5c22226c601ec5af9a84e6c0c5cf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122900
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/constant/composite.h b/src/tint/constant/composite.h
index 3bd4973..5bfa2cb 100644
--- a/src/tint/constant/composite.h
+++ b/src/tint/constant/composite.h
@@ -25,31 +25,40 @@
 namespace tint::constant {
 
 /// Composite holds a number of mixed child values.
-/// Composite may be of a vector, matrix or array type.
+/// Composite may be of a vector, matrix, array or structure type.
 /// If each element is the same type and value, then a Splat would be a more efficient constant
 /// implementation. Use CreateComposite() to create the appropriate type.
-class Composite : public Castable<Composite, constant::Value> {
+class Composite : public Castable<Composite, Value> {
   public:
     /// Constructor
     /// @param t the compsite type
     /// @param els the composite elements
     /// @param all_0 true if all elements are 0
     /// @param any_0 true if any element is 0
-    Composite(const type::Type* t,
-              utils::VectorRef<const constant::Value*> els,
-              bool all_0,
-              bool any_0);
+    Composite(const type::Type* t, utils::VectorRef<const Value*> els, bool all_0, bool any_0);
     ~Composite() override;
 
+    /// @copydoc Value::Type()
     const type::Type* Type() const override { return type; }
 
-    const constant::Value* Index(size_t i) const override {
+    /// @copydoc Value::Index()
+    const Value* Index(size_t i) const override {
         return i < elements.Length() ? elements[i] : nullptr;
     }
 
+    /// @copydoc Value::NumElements()
+    size_t NumElements() const override { return elements.Length(); }
+
+    /// @copydoc Value::AllZero()
     bool AllZero() const override { return all_zero; }
+
+    /// @copydoc Value::AnyZero()
     bool AnyZero() const override { return any_zero; }
+
+    /// @copydoc Value::AllEqual()
     bool AllEqual() const override { return false; }
+
+    /// @copydoc Value::Hash()
     size_t Hash() const override { return hash; }
 
     /// Clones the constant into the provided context
@@ -60,7 +69,7 @@
     /// The composite type
     type::Type const* const type;
     /// The composite elements
-    const utils::Vector<const constant::Value*, 4> elements;
+    const utils::Vector<const Value*, 4> elements;
     /// True if all elements are zero
     const bool all_zero;
     /// True if any element is zero
@@ -69,6 +78,7 @@
     const size_t hash;
 
   protected:
+    /// @copydoc Value::InternalValue()
     std::variant<std::monostate, AInt, AFloat> InternalValue() const override { return {}; }
 
   private:
diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h
index a412d91..2a6f6a7 100644
--- a/src/tint/constant/scalar.h
+++ b/src/tint/constant/scalar.h
@@ -25,7 +25,7 @@
 
 /// Scalar holds a single scalar or abstract-numeric value.
 template <typename T>
-class Scalar : public Castable<Scalar<T>, constant::Value> {
+class Scalar : public Castable<Scalar<T>, Value> {
   public:
     static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
                   "T must be a Number or bool");
@@ -40,13 +40,25 @@
     }
     ~Scalar() override = default;
 
+    /// @copydoc Value::Type()
     const type::Type* Type() const override { return type; }
 
-    const constant::Value* Index(size_t) const override { return nullptr; }
+    /// @return nullptr, as Scalar does not hold any elements.
+    const Value* Index(size_t) const override { return nullptr; }
 
+    /// @copydoc Value::NumElements()
+    size_t NumElements() const override { return 1; }
+
+    /// @copydoc Value::AllZero()
     bool AllZero() const override { return IsPositiveZero(); }
+
+    /// @copydoc Value::AnyZero()
     bool AnyZero() const override { return IsPositiveZero(); }
+
+    /// @copydoc Value::AllEqual()
     bool AllEqual() const override { return true; }
+
+    /// @copydoc Value::Hash()
     size_t Hash() const override { return utils::Hash(type, ValueOf()); }
 
     /// Clones the constant into the provided context
@@ -79,6 +91,7 @@
     const T value;
 
   protected:
+    /// @copydoc Value::InternalValue()
     std::variant<std::monostate, AInt, AFloat> InternalValue() const override {
         if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
             return static_cast<AFloat>(value);
diff --git a/src/tint/constant/splat.h b/src/tint/constant/splat.h
index 5494c8f..4805124 100644
--- a/src/tint/constant/splat.h
+++ b/src/tint/constant/splat.h
@@ -25,14 +25,14 @@
 /// Splat holds a single value, duplicated as all children.
 ///
 /// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is
-/// identical. Splat may be of a vector, matrix or array type.
-class Splat : public Castable<Splat, constant::Value> {
+/// identical. Splat may be of a vector, matrix, array or structure type.
+class Splat : public Castable<Splat, Value> {
   public:
     /// Constructor
     /// @param t the splat type
     /// @param e the splat element
     /// @param n the number of items in the splat
-    Splat(const type::Type* t, const constant::Value* e, size_t n);
+    Splat(const type::Type* t, const Value* e, size_t n);
     ~Splat() override;
 
     /// @returns the type of the splat
@@ -41,7 +41,10 @@
     /// Retrieve item at index @p i
     /// @param i the index to retrieve
     /// @returns the element, or nullptr if out of bounds
-    const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; }
+    const Value* Index(size_t i) const override { return i < count ? el : nullptr; }
+
+    /// @copydoc Value::NumElements()
+    size_t NumElements() const override { return count; }
 
     /// @returns true if the element is zero
     bool AllZero() const override { return el->AllZero(); }
@@ -61,7 +64,7 @@
     /// The type of the splat element
     type::Type const* const type;
     /// The element stored in the splat
-    const constant::Value* el;
+    const Value* el;
     /// The number of items in the splat
     const size_t count;
 
diff --git a/src/tint/constant/value.h b/src/tint/constant/value.h
index d5b9876..594091d 100644
--- a/src/tint/constant/value.h
+++ b/src/tint/constant/value.h
@@ -37,6 +37,7 @@
     /// @returns the type of the value
     virtual const type::Type* Type() const = 0;
 
+    /// @param i the index of the element
     /// @returns the child element with the given index, or nullptr if there are no children, or
     /// the index is out of bounds.
     ///
@@ -44,7 +45,10 @@
     /// For vectors, this returns the i'th element of the vector.
     /// For matrices, this returns the i'th column vector of the matrix.
     /// For structures, this returns the i'th member field of the structure.
-    virtual const Value* Index(size_t) const = 0;
+    virtual const Value* Index(size_t i) const = 0;
+
+    /// @return the number of elements held by this Value
+    virtual size_t NumElements() const = 0;
 
     /// @returns true if child elements are positive-zero valued.
     virtual bool AllZero() const = 0;
@@ -74,7 +78,7 @@
 
     /// @param b the value to compare too
     /// @returns true if this value is equal to @p b
-    bool Equal(const constant::Value* b) const;
+    bool Equal(const Value* b) const;
 
     /// Clones the constant into the provided context
     /// @param ctx the clone context
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 0ce4f3b..b90ea00 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -326,35 +326,20 @@
                                   const Source& source,
                                   bool use_runtime_semantics);
 
-ConstEval::Result SplatConvert(const constant::Splat* splat,
-                               ProgramBuilder& builder,
-                               const type::Type* target_ty,
-                               const Source& source,
-                               bool use_runtime_semantics) {
-    // Convert the single splatted element type.
-    auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source,
-                                   use_runtime_semantics);
-    if (!conv_el) {
-        return utils::Failure;
-    }
-    if (!conv_el.Get()) {
-        return nullptr;
-    }
-    return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
-}
-
-ConstEval::Result CompositeConvert(const constant::Composite* composite,
+ConstEval::Result CompositeConvert(const constant::Value* value,
                                    ProgramBuilder& builder,
                                    const type::Type* target_ty,
                                    const Source& source,
                                    bool use_runtime_semantics) {
+    const size_t el_count = value->NumElements();
+
     // Convert each of the composite element types.
     utils::Vector<const constant::Value*, 4> conv_els;
-    conv_els.Reserve(composite->elements.Length());
+    conv_els.Reserve(el_count);
 
     std::function<const type::Type*(size_t idx)> target_el_ty;
     if (auto* str = target_ty->As<type::Struct>()) {
-        if (TINT_UNLIKELY(str->Members().Length() != composite->elements.Length())) {
+        if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
             TINT_ICE(Resolver, builder.Diagnostics())
                 << "const-eval conversion of structure has mismatched element counts";
             return utils::Failure;
@@ -365,7 +350,8 @@
         target_el_ty = [el_ty](size_t) { return el_ty; };
     }
 
-    for (auto* el : composite->elements) {
+    for (size_t i = 0; i < el_count; i++) {
+        auto* el = value->Index(i);
         auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
                                        use_runtime_semantics);
         if (!conv_el) {
@@ -379,6 +365,40 @@
     return builder.create<constant::Composite>(target_ty, std::move(conv_els));
 }
 
+ConstEval::Result SplatConvert(const constant::Splat* splat,
+                               ProgramBuilder& builder,
+                               const type::Type* target_ty,
+                               const Source& source,
+                               bool use_runtime_semantics) {
+    const type::Type* target_el_ty = nullptr;
+    if (auto* str = target_ty->As<type::Struct>()) {
+        // Structure conversion.
+        auto members = str->Members();
+        target_el_ty = members[0]->Type();
+
+        // Structures can only be converted during materialization. The user cannot declare the
+        // target structure type, so each member type must be the same default materialization type.
+        for (size_t i = 1; i < members.Length(); i++) {
+            if (members[i]->Type() != target_el_ty) {
+                TINT_ICE(Resolver, builder.Diagnostics())
+                    << "inconsistent target struct member types for SplatConvert";
+                return utils::Failure;
+            }
+        }
+    } else {
+        target_el_ty = type::Type::ElementOf(target_ty);
+    }
+    // Convert the single splatted element type.
+    auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics);
+    if (!conv_el) {
+        return utils::Failure;
+    }
+    if (!conv_el.Get()) {
+        return nullptr;
+    }
+    return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
+}
+
 ConstEval::Result ConvertInternal(const constant::Value* c,
                                   ProgramBuilder& builder,
                                   const type::Type* target_ty,
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index 92af2de..331eb92 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "src/tint/resolver/const_eval_test.h"
+#include "src/tint/sem/materialize.h"
 
 using namespace tint::number_suffixes;  // NOLINT
 
@@ -472,5 +473,56 @@
     EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs<AFloat>().value));
 }
 
+TEST_F(ResolverConstEvalTest, StructAbstractSplat_to_StructDifferentTypes) {
+    // fn f() {
+    //   const c = modf(4.0);
+    //   var v = c;
+    // }
+    auto* expr_c = Call(builtin::Function::kModf, 0_a);
+    auto* materialized = Expr("c");
+    WrapInFunction(Decl(Const("c", expr_c)), Decl(Var("v", materialized)));
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto* c = Sem().Get(expr_c);
+    ASSERT_NE(c, nullptr);
+    EXPECT_TRUE(c->ConstantValue()->Is<constant::Splat>());
+    EXPECT_TRUE(c->ConstantValue()->AllEqual());
+    EXPECT_TRUE(c->ConstantValue()->AnyZero());
+    EXPECT_TRUE(c->ConstantValue()->AllZero());
+
+    EXPECT_TRUE(c->ConstantValue()->Index(0)->AllEqual());
+    EXPECT_TRUE(c->ConstantValue()->Index(0)->AnyZero());
+    EXPECT_TRUE(c->ConstantValue()->Index(0)->AllZero());
+    EXPECT_TRUE(c->ConstantValue()->Index(0)->Type()->Is<type::AbstractFloat>());
+    EXPECT_EQ(c->ConstantValue()->Index(0)->ValueAs<AFloat>(), 0_f);
+
+    EXPECT_TRUE(c->ConstantValue()->Index(1)->AllEqual());
+    EXPECT_TRUE(c->ConstantValue()->Index(1)->AnyZero());
+    EXPECT_TRUE(c->ConstantValue()->Index(1)->AllZero());
+    EXPECT_TRUE(c->ConstantValue()->Index(1)->Type()->Is<type::AbstractFloat>());
+    EXPECT_EQ(c->ConstantValue()->Index(1)->ValueAs<AFloat>(), 0_a);
+
+    auto* v = Sem().GetVal(materialized);
+    ASSERT_NE(v, nullptr);
+    EXPECT_TRUE(v->Is<sem::Materialize>());
+    EXPECT_TRUE(v->ConstantValue()->Is<constant::Splat>());
+    EXPECT_TRUE(v->ConstantValue()->AllEqual());
+    EXPECT_TRUE(v->ConstantValue()->AnyZero());
+    EXPECT_TRUE(v->ConstantValue()->AllZero());
+
+    EXPECT_TRUE(v->ConstantValue()->Index(0)->AllEqual());
+    EXPECT_TRUE(v->ConstantValue()->Index(0)->AnyZero());
+    EXPECT_TRUE(v->ConstantValue()->Index(0)->AllZero());
+    EXPECT_TRUE(v->ConstantValue()->Index(0)->Type()->Is<type::F32>());
+    EXPECT_EQ(v->ConstantValue()->Index(0)->ValueAs<f32>(), 0_f);
+
+    EXPECT_TRUE(v->ConstantValue()->Index(1)->AllEqual());
+    EXPECT_TRUE(v->ConstantValue()->Index(1)->AnyZero());
+    EXPECT_TRUE(v->ConstantValue()->Index(1)->AllZero());
+    EXPECT_TRUE(v->ConstantValue()->Index(1)->Type()->Is<type::F32>());
+    EXPECT_EQ(v->ConstantValue()->Index(1)->ValueAs<f32>(), 0_f);
+}
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/sem/value_expression_test.cc b/src/tint/sem/value_expression_test.cc
index 1758894..5386ffc 100644
--- a/src/tint/sem/value_expression_test.cc
+++ b/src/tint/sem/value_expression_test.cc
@@ -29,6 +29,7 @@
     ~MockConstant() override {}
     const type::Type* Type() const override { return type; }
     const constant::Value* Index(size_t) const override { return {}; }
+    size_t NumElements() const override { return 0; }
     bool AllZero() const override { return {}; }
     bool AnyZero() const override { return {}; }
     bool AllEqual() const override { return {}; }
diff --git a/test/tint/bug/chromium/1417515.wgsl b/test/tint/bug/chromium/1417515.wgsl
new file mode 100644
index 0000000..ee34253
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl
@@ -0,0 +1,3 @@
+fn foo(){
+  let s1 = modf(0.0);
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..8916eb5
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl
@@ -0,0 +1,12 @@
+struct modf_result_f32 {
+  float fract;
+  float whole;
+};
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void foo() {
+  const modf_result_f32 s1 = (modf_result_f32)0;
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..8916eb5
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl
@@ -0,0 +1,12 @@
+struct modf_result_f32 {
+  float fract;
+  float whole;
+};
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void foo() {
+  const modf_result_f32 s1 = (modf_result_f32)0;
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.glsl b/test/tint/bug/chromium/1417515.wgsl.expected.glsl
new file mode 100644
index 0000000..75b85f1
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.glsl
@@ -0,0 +1,16 @@
+#version 310 es
+
+struct modf_result_f32 {
+  float fract;
+  float whole;
+};
+
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+  return;
+}
+void foo() {
+  modf_result_f32 s1 = modf_result_f32(0.0f, 0.0f);
+}
+
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.msl b/test/tint/bug/chromium/1417515.wgsl.expected.msl
new file mode 100644
index 0000000..3433613
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.msl
@@ -0,0 +1,12 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+struct modf_result_f32 {
+  float fract;
+  float whole;
+};
+void foo() {
+  modf_result_f32 const s1 = modf_result_f32{};
+}
+
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.spvasm b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm
new file mode 100644
index 0000000..09a7770
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm
@@ -0,0 +1,29 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 10
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %foo "foo"
+               OpName %__modf_result_f32 "__modf_result_f32"
+               OpMemberName %__modf_result_f32 0 "fract"
+               OpMemberName %__modf_result_f32 1 "whole"
+               OpMemberDecorate %__modf_result_f32 0 Offset 0
+               OpMemberDecorate %__modf_result_f32 1 Offset 4
+       %void = OpTypeVoid
+          %1 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+%__modf_result_f32 = OpTypeStruct %float %float
+          %9 = OpConstantNull %__modf_result_f32
+%unused_entry_point = OpFunction %void None %1
+          %4 = OpLabel
+               OpReturn
+               OpFunctionEnd
+        %foo = OpFunction %void None %1
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.wgsl b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl
new file mode 100644
index 0000000..0750da8
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl
@@ -0,0 +1,3 @@
+fn foo() {
+  let s1 = modf(0.0);
+}