Implement a FoldConstants transform that currently folds scalar and vector conversions
This is required for implementing module-level conversions in the spir-v
backend (upcoming CL).
Bug: tint:865
Change-Id: I7fd38c6b1628c791851917165991bc247fc113c2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54740
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 5de73b2..8899769 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -562,6 +562,8 @@
"transform/external_texture_transform.h",
"transform/first_index_offset.cc",
"transform/first_index_offset.h",
+ "transform/fold_constants.cc",
+ "transform/fold_constants.h",
"transform/inline_pointer_lets.cc",
"transform/inline_pointer_lets.h",
"transform/manager.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 97d8f9d..63a1663 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -287,6 +287,8 @@
transform/external_texture_transform.h
transform/first_index_offset.cc
transform/first_index_offset.h
+ transform/fold_constants.cc
+ transform/fold_constants.h
transform/inline_pointer_lets.cc
transform/inline_pointer_lets.h
transform/manager.cc
@@ -861,6 +863,7 @@
transform/decompose_storage_access_test.cc
transform/external_texture_transform_test.cc
transform/first_index_offset_test.cc
+ transform/fold_constants_test.cc
transform/inline_pointer_lets_test.cc
transform/pad_array_elements_test.cc
transform/promote_initializers_to_const_var_test.cc
diff --git a/src/program_builder.h b/src/program_builder.h
index ae6f5d2..a1c7b46 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -1099,6 +1099,18 @@
}
/// @param args the arguments for the vector constructor
+ /// @param type the vector type
+ /// @param size the vector size
+ /// @return an `ast::TypeConstructorExpression` of a `size`-element vector of
+ /// type `type`, constructed with the values `args`.
+ template <typename... ARGS>
+ ast::TypeConstructorExpression* vec(ast::Type* type,
+ uint32_t size,
+ ARGS&&... args) {
+ return Construct(ty.vec(type, size), std::forward<ARGS>(args)...);
+ }
+
+ /// @param args the arguments for the vector constructor
/// @return an `ast::TypeConstructorExpression` of a 2-element vector of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
diff --git a/src/transform/fold_constants.cc b/src/transform/fold_constants.cc
new file mode 100644
index 0000000..9ef726b
--- /dev/null
+++ b/src/transform/fold_constants.cc
@@ -0,0 +1,353 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/fold_constants.h"
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/program_builder.h"
+
+namespace tint {
+
+namespace {
+
+using i32 = ProgramBuilder::i32;
+using u32 = ProgramBuilder::u32;
+using f32 = ProgramBuilder::f32;
+
+/// A Value is a sequence of scalars
+struct Value {
+ enum class Type {
+ i32, //
+ u32,
+ f32,
+ bool_
+ };
+
+ union Scalar {
+ ProgramBuilder::i32 i32;
+ ProgramBuilder::u32 u32;
+ ProgramBuilder::f32 f32;
+ bool bool_;
+
+ Scalar(ProgramBuilder::i32 v) : i32(v) {} // NOLINT
+ Scalar(ProgramBuilder::u32 v) : u32(v) {} // NOLINT
+ Scalar(ProgramBuilder::f32 v) : f32(v) {} // NOLINT
+ Scalar(bool v) : bool_(v) {} // NOLINT
+ };
+
+ using Elems = std::vector<Scalar>;
+
+ Type type;
+ Elems elems;
+
+ Value() {}
+
+ Value(ProgramBuilder::i32 v) : type(Type::i32), elems{v} {} // NOLINT
+ Value(ProgramBuilder::u32 v) : type(Type::u32), elems{v} {} // NOLINT
+ Value(ProgramBuilder::f32 v) : type(Type::f32), elems{v} {} // NOLINT
+ Value(bool v) : type(Type::bool_), elems{v} {} // NOLINT
+
+ explicit Value(Type t, Elems e = {}) : type(t), elems(std::move(e)) {}
+
+ bool Valid() const { return elems.size() != 0; }
+ operator bool() const { return Valid(); }
+
+ void Append(const Value& value) {
+ TINT_ASSERT(value.type == type);
+ elems.insert(elems.end(), value.elems.begin(), value.elems.end());
+ }
+
+ /// Calls `func`(s) with s being the current scalar value at `index`.
+ /// `func` is typically a lambda of the form '[](auto&& s)'.
+ template <typename Func>
+ auto WithScalarAt(size_t index, Func&& func) const {
+ switch (type) {
+ case Value::Type::i32: {
+ return func(elems[index].i32);
+ }
+ case Value::Type::u32: {
+ return func(elems[index].u32);
+ }
+ case Value::Type::f32: {
+ return func(elems[index].f32);
+ }
+ case Value::Type::bool_: {
+ return func(elems[index].bool_);
+ }
+ }
+ TINT_ASSERT(false && "Unreachable");
+ return func(~0);
+ }
+};
+
+/// Returns the Value::Type that maps to the ast::Type*
+Value::Type AstToValueType(ast::Type* t) {
+ if (t->Is<ast::I32>()) {
+ return Value::Type::i32;
+ } else if (t->Is<ast::U32>()) {
+ return Value::Type::u32;
+ } else if (t->Is<ast::F32>()) {
+ return Value::Type::f32;
+ } else if (t->Is<ast::Bool>()) {
+ return Value::Type::bool_;
+ }
+ TINT_ASSERT(false && "Invalid type");
+ return {};
+}
+
+/// Cast `Value` to `target_type`
+/// @return the casted value
+Value Cast(const Value& value, Value::Type target_type) {
+ if (value.type == target_type) {
+ return value;
+ }
+
+ Value result(target_type);
+ for (size_t i = 0; i < value.elems.size(); ++i) {
+ switch (target_type) {
+ case Value::Type::i32:
+ result.Append(value.WithScalarAt(
+ i, [](auto&& s) { return static_cast<i32>(s); }));
+ break;
+
+ case Value::Type::u32:
+ result.Append(value.WithScalarAt(
+ i, [](auto&& s) { return static_cast<u32>(s); }));
+ break;
+
+ case Value::Type::f32:
+ result.Append(value.WithScalarAt(
+ i, [](auto&& s) { return static_cast<f32>(s); }));
+ break;
+
+ case Value::Type::bool_:
+ result.Append(value.WithScalarAt(
+ i, [](auto&& s) { return static_cast<bool>(s); }));
+ break;
+ }
+ }
+
+ return result;
+}
+
+/// Type that maps `ast::Expression*` to `Value`
+using ExprToValue = std::unordered_map<const ast::Expression*, Value>;
+
+/// Adds mapping of `expr` to `value` to `expr_to_value`
+/// @returns true if add succeded
+bool AddExpr(ExprToValue& expr_to_value,
+ const ast::Expression* expr,
+ Value value) {
+ auto r = expr_to_value.emplace(expr, std::move(value));
+ return r.second;
+}
+
+/// @returns the `Value` in `expr_to_value` at `expr`, leaving it in the map, or
+/// invalid Value if not in map
+Value PeekExpr(ExprToValue& expr_to_value, ast::Expression* expr) {
+ auto iter = expr_to_value.find(expr);
+ if (iter != expr_to_value.end()) {
+ return iter->second;
+ }
+ return {};
+}
+
+/// @returns the `Value` in `expr_to_value` at `expr`, removing it from the map,
+/// or invalid Value if not in map
+Value TakeExpr(ExprToValue& expr_to_value, ast::Expression* expr) {
+ auto iter = expr_to_value.find(expr);
+ if (iter != expr_to_value.end()) {
+ auto result = std::move(iter->second);
+ expr_to_value.erase(iter);
+ return result;
+ }
+ return {};
+}
+
+/// Folds a `ScalarConstructorExpression` into a `Value`
+Value Fold(const ast::ScalarConstructorExpression* scalar_ctor) {
+ auto* literal = scalar_ctor->literal();
+ if (auto* lit = literal->As<ast::SintLiteral>()) {
+ return {lit->value_as_i32()};
+ }
+ if (auto* lit = literal->As<ast::UintLiteral>()) {
+ return {lit->value_as_u32()};
+ }
+ if (auto* lit = literal->As<ast::FloatLiteral>()) {
+ return {lit->value()};
+ }
+ if (auto* lit = literal->As<ast::BoolLiteral>()) {
+ return {lit->IsTrue()};
+ }
+ TINT_ASSERT(false && "Unreachable");
+ return {};
+}
+
+/// Folds a `TypeConstructorExpression` into a `Value` if possible.
+/// @returns a valid `Value` with 1 element for scalars, and 2/3/4 elements for
+/// vectors.
+Value Fold(const ast::TypeConstructorExpression* type_ctor,
+ ExprToValue& expr_to_value) {
+ auto& ctor_values = type_ctor->values();
+ auto* type = type_ctor->type();
+ auto* vec = type->As<ast::Vector>();
+
+ // For now, only fold scalars and vectors
+ if (!type->is_scalar() && !vec) {
+ return {};
+ }
+
+ auto* elem_type = vec ? vec->type() : type;
+ int result_size = vec ? static_cast<int>(vec->size()) : 1;
+
+ // For zero value init, return 0s
+ if (ctor_values.empty()) {
+ if (elem_type->Is<ast::I32>()) {
+ return Value(Value::Type::i32, Value::Elems(result_size, 0));
+ } else if (elem_type->Is<ast::U32>()) {
+ return Value(Value::Type::u32, Value::Elems(result_size, 0u));
+ } else if (elem_type->Is<ast::F32>()) {
+ return Value(Value::Type::f32, Value::Elems(result_size, 0.0f));
+ } else if (elem_type->Is<ast::Bool>()) {
+ return Value(Value::Type::bool_, Value::Elems(result_size, false));
+ }
+ }
+
+ // If not all ctor_values are foldable, we can't fold this node
+ for (auto* cv : ctor_values) {
+ if (!PeekExpr(expr_to_value, cv)) {
+ return {};
+ }
+ }
+
+ // Build value for type_ctor from each child value by casting to
+ // type_ctor's type.
+ Value new_value(AstToValueType(elem_type));
+ for (auto* cv : ctor_values) {
+ auto value = TakeExpr(expr_to_value, cv);
+ new_value.Append(Cast(value, AstToValueType(elem_type)));
+ }
+
+ // Splat single-value initializers
+ if (new_value.elems.size() == 1) {
+ auto first_value = new_value;
+ for (int i = 0; i < result_size - 1; ++i) {
+ new_value.Append(first_value);
+ }
+ }
+
+ return new_value;
+}
+
+/// @returns a `ConstructorExpression` to replace `expr` with, or nullptr if we
+/// shouldn't replace it.
+ast::ConstructorExpression* Build(CloneContext& ctx,
+ const ast::Expression* expr,
+ const Value& value) {
+ // If original ctor expression had no init values, don't replace the
+ // expression
+ if (auto* ctor = expr->As<ast::TypeConstructorExpression>()) {
+ if (ctor->values().size() == 0) {
+ return nullptr;
+ }
+ }
+
+ auto make_ast_type = [&]() -> ast::Type* {
+ switch (value.type) {
+ case Value::Type::i32:
+ return ctx.dst->ty.i32();
+ case Value::Type::u32:
+ return ctx.dst->ty.u32();
+ case Value::Type::f32:
+ return ctx.dst->ty.f32();
+ case Value::Type::bool_:
+ return ctx.dst->ty.bool_();
+ }
+ return nullptr;
+ };
+
+ if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
+ if (auto* vec = type_ctor->type()->As<ast::Vector>()) {
+ uint32_t vec_size = static_cast<uint32_t>(vec->size());
+
+ // We'd like to construct the new vector with the same number of
+ // constructor args that the original node had, but after folding
+ // constants, cases like the following are problematic:
+ //
+ // vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
+ //
+ // In this case, creating a vec3 with 2 args is invalid, so we should
+ // create it with 3. So what we do is construct with vec_size args,
+ // except if the original vector was single-value initialized, in which
+ // case, we only construct with one arg again.
+ uint32_t ctor_size = (type_ctor->values().size() == 1) ? 1 : vec_size;
+
+ ast::ExpressionList ctors;
+ for (uint32_t i = 0; i < ctor_size; ++i) {
+ value.WithScalarAt(
+ i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
+ }
+
+ return ctx.dst->vec(make_ast_type(), vec_size, ctors);
+ } else if (type_ctor->type()->is_scalar()) {
+ return value.WithScalarAt(0, [&](auto&& s) { return ctx.dst->Expr(s); });
+ }
+ }
+ return nullptr;
+}
+
+} // namespace
+
+namespace transform {
+
+FoldConstants::FoldConstants() = default;
+
+FoldConstants::~FoldConstants() = default;
+
+Output FoldConstants::Run(const Program* in, const DataMap&) {
+ ProgramBuilder out;
+ CloneContext ctx(&out, in);
+
+ ExprToValue expr_to_value;
+
+ // Visit inner expressions before outer expressions
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* scalar_ctor = node->As<ast::ScalarConstructorExpression>()) {
+ if (auto v = Fold(scalar_ctor)) {
+ AddExpr(expr_to_value, scalar_ctor, std::move(v));
+ }
+ }
+ if (auto* type_ctor = node->As<ast::TypeConstructorExpression>()) {
+ if (auto v = Fold(type_ctor, expr_to_value)) {
+ AddExpr(expr_to_value, type_ctor, std::move(v));
+ }
+ }
+ }
+
+ for (auto& kvp : expr_to_value) {
+ if (auto* ctor_expr = Build(ctx, kvp.first, kvp.second)) {
+ ctx.Replace(kvp.first, ctor_expr);
+ }
+ }
+
+ ctx.Clone();
+
+ return Output(Program(std::move(out)));
+}
+
+} // namespace transform
+} // namespace tint
diff --git a/src/transform/fold_constants.h b/src/transform/fold_constants.h
new file mode 100644
index 0000000..7e18337
--- /dev/null
+++ b/src/transform/fold_constants.h
@@ -0,0 +1,42 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TRANSFORM_FOLD_CONSTANTS_H_
+#define SRC_TRANSFORM_FOLD_CONSTANTS_H_
+
+#include "src/transform/transform.h"
+
+namespace tint {
+namespace transform {
+
+/// FoldConstants transforms the AST by folding constant expressions
+class FoldConstants : public Transform {
+ public:
+ /// Constructor
+ FoldConstants();
+
+ /// Destructor
+ ~FoldConstants() override;
+
+ /// Runs the transform on `program`, returning the transformation result.
+ /// @param program the source program to transform
+ /// @param data optional extra transform-specific input data
+ /// @returns the transformation result
+ Output Run(const Program* program, const DataMap& data = {}) override;
+};
+
+} // namespace transform
+} // namespace tint
+
+#endif // SRC_TRANSFORM_FOLD_CONSTANTS_H_
diff --git a/src/transform/fold_constants_test.cc b/src/transform/fold_constants_test.cc
new file mode 100644
index 0000000..5736c89
--- /dev/null
+++ b/src/transform/fold_constants_test.cc
@@ -0,0 +1,427 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/fold_constants.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using FoldConstantsTest = TransformTest;
+
+TEST_F(FoldConstantsTest, Module_Scalar_NoConversion) {
+ auto* src = R"(
+var<private> a : i32 = i32(123);
+var<private> b : u32 = u32(123u);
+var<private> c : f32 = f32(123.0);
+var<private> d : bool = bool(true);
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : i32 = 123;
+
+var<private> b : u32 = 123u;
+
+var<private> c : f32 = 123.0;
+
+var<private> d : bool = true;
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Scalar_Conversion) {
+ auto* src = R"(
+var<private> a : i32 = i32(123.0);
+var<private> b : u32 = u32(123);
+var<private> c : f32 = f32(123u);
+var<private> d : bool = bool(123);
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : i32 = 123;
+
+var<private> b : u32 = 123u;
+
+var<private> c : f32 = 123.0;
+
+var<private> d : bool = true;
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Scalar_MultipleConversions) {
+ auto* src = R"(
+var<private> a : i32 = i32(u32(f32(u32(i32(123.0)))));
+var<private> b : u32 = u32(i32(f32(i32(u32(123)))));
+var<private> c : f32 = f32(u32(i32(u32(f32(123u)))));
+var<private> d : bool = bool(i32(f32(i32(u32(123)))));
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : i32 = 123;
+
+var<private> b : u32 = 123u;
+
+var<private> c : f32 = 123.0;
+
+var<private> d : bool = true;
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Vector_NoConversion) {
+ auto* src = R"(
+var<private> a : vec3<i32> = vec3<i32>(123);
+var<private> b : vec3<u32> = vec3<u32>(123u);
+var<private> c : vec3<f32> = vec3<f32>(123.0);
+var<private> d : vec3<bool> = vec3<bool>(true);
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : vec3<i32> = vec3<i32>(123);
+
+var<private> b : vec3<u32> = vec3<u32>(123u);
+
+var<private> c : vec3<f32> = vec3<f32>(123.0);
+
+var<private> d : vec3<bool> = vec3<bool>(true);
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Vector_Conversion) {
+ auto* src = R"(
+var<private> a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
+var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(123));
+var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
+var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(123));
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : vec3<i32> = vec3<i32>(123);
+
+var<private> b : vec3<u32> = vec3<u32>(123u);
+
+var<private> c : vec3<f32> = vec3<f32>(123.0);
+
+var<private> d : vec3<bool> = vec3<bool>(true);
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Vector_MultipleConversions) {
+ auto* src = R"(
+var<private> a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
+var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
+var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
+var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : vec3<i32> = vec3<i32>(123);
+
+var<private> b : vec3<u32> = vec3<u32>(123u);
+
+var<private> c : vec3<f32> = vec3<f32>(123.0);
+
+var<private> d : vec3<bool> = vec3<bool>(true);
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Module_Vector_MixedSizeConversions) {
+ auto* src = R"(
+var<private> a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
+var<private> b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
+var<private> c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
+var<private> d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
+var<private> e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
+
+fn f() {
+}
+)";
+
+ auto* expect = R"(
+var<private> a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+
+var<private> b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
+
+var<private> c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+
+var<private> d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+
+var<private> e : vec4<bool> = vec4<bool>(false, true, false, true);
+
+fn f() {
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Scalar_NoConversion) {
+ auto* src = R"(
+fn f() {
+ var a : i32 = i32(123);
+ var b : u32 = u32(123u);
+ var c : f32 = f32(123.0);
+ var d : bool = bool(true);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : i32 = 123;
+ var b : u32 = 123u;
+ var c : f32 = 123.0;
+ var d : bool = true;
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Scalar_Conversion) {
+ auto* src = R"(
+fn f() {
+ var a : i32 = i32(123.0);
+ var b : u32 = u32(123);
+ var c : f32 = f32(123u);
+ var d : bool = bool(123);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : i32 = 123;
+ var b : u32 = 123u;
+ var c : f32 = 123.0;
+ var d : bool = true;
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Scalar_MultipleConversions) {
+ auto* src = R"(
+fn f() {
+ var a : i32 = i32(u32(f32(u32(i32(123.0)))));
+ var b : u32 = u32(i32(f32(i32(u32(123)))));
+ var c : f32 = f32(u32(i32(u32(f32(123u)))));
+ var d : bool = bool(i32(f32(i32(u32(123)))));
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : i32 = 123;
+ var b : u32 = 123u;
+ var c : f32 = 123.0;
+ var d : bool = true;
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Vector_NoConversion) {
+ auto* src = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(123);
+ var b : vec3<u32> = vec3<u32>(123u);
+ var c : vec3<f32> = vec3<f32>(123.0);
+ var d : vec3<bool> = vec3<bool>(true);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(123);
+ var b : vec3<u32> = vec3<u32>(123u);
+ var c : vec3<f32> = vec3<f32>(123.0);
+ var d : vec3<bool> = vec3<bool>(true);
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Vector_Conversion) {
+ auto* src = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
+ var b : vec3<u32> = vec3<u32>(vec3<i32>(123));
+ var c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
+ var d : vec3<bool> = vec3<bool>(vec3<i32>(123));
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(123);
+ var b : vec3<u32> = vec3<u32>(123u);
+ var c : vec3<f32> = vec3<f32>(123.0);
+ var d : vec3<bool> = vec3<bool>(true);
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Vector_MultipleConversions) {
+ auto* src = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
+ var b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
+ var c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
+ var d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : vec3<i32> = vec3<i32>(123);
+ var b : vec3<u32> = vec3<u32>(123u);
+ var c : vec3<f32> = vec3<f32>(123.0);
+ var d : vec3<bool> = vec3<bool>(true);
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Vector_MixedSizeConversions) {
+ auto* src = R"(
+fn f() {
+ var a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
+ var b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
+ var c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
+ var d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
+ var e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+ var b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
+ var c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+ var d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
+ var e : vec4<bool> = vec4<bool>(false, true, false, true);
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldConstantsTest, Function_Vector_ConstantWithNonConstant) {
+ auto* src = R"(
+fn f() {
+ var a : f32 = f32();
+ var b : vec2<f32> = vec2<f32>(f32(i32(1)), a);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var a : f32 = f32();
+ var b : vec2<f32> = vec2<f32>(1.0, a);
+}
+)";
+
+ auto got = Run<FoldConstants>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace transform
+} // namespace tint