tint: Fix constant::Splat conversion of struct types
Conversion can happen for structure materialization (modf, frexp).
If both structure members are the same type and value, then a constant::Splat will be constructed, which needs to handle conversion.
Bug: chromium:1417515
Change-Id: Iadd14ce00b8d5c22226c601ec5af9a84e6c0c5cf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122900
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/constant/composite.h b/src/tint/constant/composite.h
index 3bd4973..5bfa2cb 100644
--- a/src/tint/constant/composite.h
+++ b/src/tint/constant/composite.h
@@ -25,31 +25,40 @@
namespace tint::constant {
/// Composite holds a number of mixed child values.
-/// Composite may be of a vector, matrix or array type.
+/// Composite may be of a vector, matrix, array or structure type.
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate type.
-class Composite : public Castable<Composite, constant::Value> {
+class Composite : public Castable<Composite, Value> {
public:
/// Constructor
/// @param t the compsite type
/// @param els the composite elements
/// @param all_0 true if all elements are 0
/// @param any_0 true if any element is 0
- Composite(const type::Type* t,
- utils::VectorRef<const constant::Value*> els,
- bool all_0,
- bool any_0);
+ Composite(const type::Type* t, utils::VectorRef<const Value*> els, bool all_0, bool any_0);
~Composite() override;
+ /// @copydoc Value::Type()
const type::Type* Type() const override { return type; }
- const constant::Value* Index(size_t i) const override {
+ /// @copydoc Value::Index()
+ const Value* Index(size_t i) const override {
return i < elements.Length() ? elements[i] : nullptr;
}
+ /// @copydoc Value::NumElements()
+ size_t NumElements() const override { return elements.Length(); }
+
+ /// @copydoc Value::AllZero()
bool AllZero() const override { return all_zero; }
+
+ /// @copydoc Value::AnyZero()
bool AnyZero() const override { return any_zero; }
+
+ /// @copydoc Value::AllEqual()
bool AllEqual() const override { return false; }
+
+ /// @copydoc Value::Hash()
size_t Hash() const override { return hash; }
/// Clones the constant into the provided context
@@ -60,7 +69,7 @@
/// The composite type
type::Type const* const type;
/// The composite elements
- const utils::Vector<const constant::Value*, 4> elements;
+ const utils::Vector<const Value*, 4> elements;
/// True if all elements are zero
const bool all_zero;
/// True if any element is zero
@@ -69,6 +78,7 @@
const size_t hash;
protected:
+ /// @copydoc Value::InternalValue()
std::variant<std::monostate, AInt, AFloat> InternalValue() const override { return {}; }
private:
diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h
index a412d91..2a6f6a7 100644
--- a/src/tint/constant/scalar.h
+++ b/src/tint/constant/scalar.h
@@ -25,7 +25,7 @@
/// Scalar holds a single scalar or abstract-numeric value.
template <typename T>
-class Scalar : public Castable<Scalar<T>, constant::Value> {
+class Scalar : public Castable<Scalar<T>, Value> {
public:
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool");
@@ -40,13 +40,25 @@
}
~Scalar() override = default;
+ /// @copydoc Value::Type()
const type::Type* Type() const override { return type; }
- const constant::Value* Index(size_t) const override { return nullptr; }
+ /// @return nullptr, as Scalar does not hold any elements.
+ const Value* Index(size_t) const override { return nullptr; }
+ /// @copydoc Value::NumElements()
+ size_t NumElements() const override { return 1; }
+
+ /// @copydoc Value::AllZero()
bool AllZero() const override { return IsPositiveZero(); }
+
+ /// @copydoc Value::AnyZero()
bool AnyZero() const override { return IsPositiveZero(); }
+
+ /// @copydoc Value::AllEqual()
bool AllEqual() const override { return true; }
+
+ /// @copydoc Value::Hash()
size_t Hash() const override { return utils::Hash(type, ValueOf()); }
/// Clones the constant into the provided context
@@ -79,6 +91,7 @@
const T value;
protected:
+ /// @copydoc Value::InternalValue()
std::variant<std::monostate, AInt, AFloat> InternalValue() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return static_cast<AFloat>(value);
diff --git a/src/tint/constant/splat.h b/src/tint/constant/splat.h
index 5494c8f..4805124 100644
--- a/src/tint/constant/splat.h
+++ b/src/tint/constant/splat.h
@@ -25,14 +25,14 @@
/// Splat holds a single value, duplicated as all children.
///
/// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is
-/// identical. Splat may be of a vector, matrix or array type.
-class Splat : public Castable<Splat, constant::Value> {
+/// identical. Splat may be of a vector, matrix, array or structure type.
+class Splat : public Castable<Splat, Value> {
public:
/// Constructor
/// @param t the splat type
/// @param e the splat element
/// @param n the number of items in the splat
- Splat(const type::Type* t, const constant::Value* e, size_t n);
+ Splat(const type::Type* t, const Value* e, size_t n);
~Splat() override;
/// @returns the type of the splat
@@ -41,7 +41,10 @@
/// Retrieve item at index @p i
/// @param i the index to retrieve
/// @returns the element, or nullptr if out of bounds
- const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; }
+ const Value* Index(size_t i) const override { return i < count ? el : nullptr; }
+
+ /// @copydoc Value::NumElements()
+ size_t NumElements() const override { return count; }
/// @returns true if the element is zero
bool AllZero() const override { return el->AllZero(); }
@@ -61,7 +64,7 @@
/// The type of the splat element
type::Type const* const type;
/// The element stored in the splat
- const constant::Value* el;
+ const Value* el;
/// The number of items in the splat
const size_t count;
diff --git a/src/tint/constant/value.h b/src/tint/constant/value.h
index d5b9876..594091d 100644
--- a/src/tint/constant/value.h
+++ b/src/tint/constant/value.h
@@ -37,6 +37,7 @@
/// @returns the type of the value
virtual const type::Type* Type() const = 0;
+ /// @param i the index of the element
/// @returns the child element with the given index, or nullptr if there are no children, or
/// the index is out of bounds.
///
@@ -44,7 +45,10 @@
/// For vectors, this returns the i'th element of the vector.
/// For matrices, this returns the i'th column vector of the matrix.
/// For structures, this returns the i'th member field of the structure.
- virtual const Value* Index(size_t) const = 0;
+ virtual const Value* Index(size_t i) const = 0;
+
+ /// @return the number of elements held by this Value
+ virtual size_t NumElements() const = 0;
/// @returns true if child elements are positive-zero valued.
virtual bool AllZero() const = 0;
@@ -74,7 +78,7 @@
/// @param b the value to compare too
/// @returns true if this value is equal to @p b
- bool Equal(const constant::Value* b) const;
+ bool Equal(const Value* b) const;
/// Clones the constant into the provided context
/// @param ctx the clone context
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 0ce4f3b..b90ea00 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -326,35 +326,20 @@
const Source& source,
bool use_runtime_semantics);
-ConstEval::Result SplatConvert(const constant::Splat* splat,
- ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source,
- bool use_runtime_semantics) {
- // Convert the single splatted element type.
- auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source,
- use_runtime_semantics);
- if (!conv_el) {
- return utils::Failure;
- }
- if (!conv_el.Get()) {
- return nullptr;
- }
- return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
-}
-
-ConstEval::Result CompositeConvert(const constant::Composite* composite,
+ConstEval::Result CompositeConvert(const constant::Value* value,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source,
bool use_runtime_semantics) {
+ const size_t el_count = value->NumElements();
+
// Convert each of the composite element types.
utils::Vector<const constant::Value*, 4> conv_els;
- conv_els.Reserve(composite->elements.Length());
+ conv_els.Reserve(el_count);
std::function<const type::Type*(size_t idx)> target_el_ty;
if (auto* str = target_ty->As<type::Struct>()) {
- if (TINT_UNLIKELY(str->Members().Length() != composite->elements.Length())) {
+ if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "const-eval conversion of structure has mismatched element counts";
return utils::Failure;
@@ -365,7 +350,8 @@
target_el_ty = [el_ty](size_t) { return el_ty; };
}
- for (auto* el : composite->elements) {
+ for (size_t i = 0; i < el_count; i++) {
+ auto* el = value->Index(i);
auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
use_runtime_semantics);
if (!conv_el) {
@@ -379,6 +365,40 @@
return builder.create<constant::Composite>(target_ty, std::move(conv_els));
}
+ConstEval::Result SplatConvert(const constant::Splat* splat,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source,
+ bool use_runtime_semantics) {
+ const type::Type* target_el_ty = nullptr;
+ if (auto* str = target_ty->As<type::Struct>()) {
+ // Structure conversion.
+ auto members = str->Members();
+ target_el_ty = members[0]->Type();
+
+ // Structures can only be converted during materialization. The user cannot declare the
+ // target structure type, so each member type must be the same default materialization type.
+ for (size_t i = 1; i < members.Length(); i++) {
+ if (members[i]->Type() != target_el_ty) {
+ TINT_ICE(Resolver, builder.Diagnostics())
+ << "inconsistent target struct member types for SplatConvert";
+ return utils::Failure;
+ }
+ }
+ } else {
+ target_el_ty = type::Type::ElementOf(target_ty);
+ }
+ // Convert the single splatted element type.
+ auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics);
+ if (!conv_el) {
+ return utils::Failure;
+ }
+ if (!conv_el.Get()) {
+ return nullptr;
+ }
+ return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
+}
+
ConstEval::Result ConvertInternal(const constant::Value* c,
ProgramBuilder& builder,
const type::Type* target_ty,
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index 92af2de..331eb92 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "src/tint/resolver/const_eval_test.h"
+#include "src/tint/sem/materialize.h"
using namespace tint::number_suffixes; // NOLINT
@@ -472,5 +473,56 @@
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs<AFloat>().value));
}
+TEST_F(ResolverConstEvalTest, StructAbstractSplat_to_StructDifferentTypes) {
+ // fn f() {
+ // const c = modf(4.0);
+ // var v = c;
+ // }
+ auto* expr_c = Call(builtin::Function::kModf, 0_a);
+ auto* materialized = Expr("c");
+ WrapInFunction(Decl(Const("c", expr_c)), Decl(Var("v", materialized)));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* c = Sem().Get(expr_c);
+ ASSERT_NE(c, nullptr);
+ EXPECT_TRUE(c->ConstantValue()->Is<constant::Splat>());
+ EXPECT_TRUE(c->ConstantValue()->AllEqual());
+ EXPECT_TRUE(c->ConstantValue()->AnyZero());
+ EXPECT_TRUE(c->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(c->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(c->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(c->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(c->ConstantValue()->Index(0)->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(c->ConstantValue()->Index(0)->ValueAs<AFloat>(), 0_f);
+
+ EXPECT_TRUE(c->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(c->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(c->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(c->ConstantValue()->Index(1)->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(c->ConstantValue()->Index(1)->ValueAs<AFloat>(), 0_a);
+
+ auto* v = Sem().GetVal(materialized);
+ ASSERT_NE(v, nullptr);
+ EXPECT_TRUE(v->Is<sem::Materialize>());
+ EXPECT_TRUE(v->ConstantValue()->Is<constant::Splat>());
+ EXPECT_TRUE(v->ConstantValue()->AllEqual());
+ EXPECT_TRUE(v->ConstantValue()->AnyZero());
+ EXPECT_TRUE(v->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(v->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_TRUE(v->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_TRUE(v->ConstantValue()->Index(0)->AllZero());
+ EXPECT_TRUE(v->ConstantValue()->Index(0)->Type()->Is<type::F32>());
+ EXPECT_EQ(v->ConstantValue()->Index(0)->ValueAs<f32>(), 0_f);
+
+ EXPECT_TRUE(v->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_TRUE(v->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_TRUE(v->ConstantValue()->Index(1)->AllZero());
+ EXPECT_TRUE(v->ConstantValue()->Index(1)->Type()->Is<type::F32>());
+ EXPECT_EQ(v->ConstantValue()->Index(1)->ValueAs<f32>(), 0_f);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/sem/value_expression_test.cc b/src/tint/sem/value_expression_test.cc
index 1758894..5386ffc 100644
--- a/src/tint/sem/value_expression_test.cc
+++ b/src/tint/sem/value_expression_test.cc
@@ -29,6 +29,7 @@
~MockConstant() override {}
const type::Type* Type() const override { return type; }
const constant::Value* Index(size_t) const override { return {}; }
+ size_t NumElements() const override { return 0; }
bool AllZero() const override { return {}; }
bool AnyZero() const override { return {}; }
bool AllEqual() const override { return {}; }
diff --git a/test/tint/bug/chromium/1417515.wgsl b/test/tint/bug/chromium/1417515.wgsl
new file mode 100644
index 0000000..ee34253
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl
@@ -0,0 +1,3 @@
+fn foo(){
+ let s1 = modf(0.0);
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..8916eb5
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl
@@ -0,0 +1,12 @@
+struct modf_result_f32 {
+ float fract;
+ float whole;
+};
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void foo() {
+ const modf_result_f32 s1 = (modf_result_f32)0;
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..8916eb5
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl
@@ -0,0 +1,12 @@
+struct modf_result_f32 {
+ float fract;
+ float whole;
+};
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void foo() {
+ const modf_result_f32 s1 = (modf_result_f32)0;
+}
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.glsl b/test/tint/bug/chromium/1417515.wgsl.expected.glsl
new file mode 100644
index 0000000..75b85f1
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.glsl
@@ -0,0 +1,16 @@
+#version 310 es
+
+struct modf_result_f32 {
+ float fract;
+ float whole;
+};
+
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+ return;
+}
+void foo() {
+ modf_result_f32 s1 = modf_result_f32(0.0f, 0.0f);
+}
+
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.msl b/test/tint/bug/chromium/1417515.wgsl.expected.msl
new file mode 100644
index 0000000..3433613
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.msl
@@ -0,0 +1,12 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+struct modf_result_f32 {
+ float fract;
+ float whole;
+};
+void foo() {
+ modf_result_f32 const s1 = modf_result_f32{};
+}
+
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.spvasm b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm
new file mode 100644
index 0000000..09a7770
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm
@@ -0,0 +1,29 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 10
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+ OpExecutionMode %unused_entry_point LocalSize 1 1 1
+ OpName %unused_entry_point "unused_entry_point"
+ OpName %foo "foo"
+ OpName %__modf_result_f32 "__modf_result_f32"
+ OpMemberName %__modf_result_f32 0 "fract"
+ OpMemberName %__modf_result_f32 1 "whole"
+ OpMemberDecorate %__modf_result_f32 0 Offset 0
+ OpMemberDecorate %__modf_result_f32 1 Offset 4
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+%__modf_result_f32 = OpTypeStruct %float %float
+ %9 = OpConstantNull %__modf_result_f32
+%unused_entry_point = OpFunction %void None %1
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %foo = OpFunction %void None %1
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.wgsl b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl
new file mode 100644
index 0000000..0750da8
--- /dev/null
+++ b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl
@@ -0,0 +1,3 @@
+fn foo() {
+ let s1 = modf(0.0);
+}