Import Tint changes from Dawn
Changes:
- 426b47e4816542659fdce29b3c689045d288e199 tint: add missing F16 conversion expression support by Zhaoming Jiang <zhaoming.jiang@intel.com>
- 0c7c23b9c4b84d0c241b4e970116d7f89c3fa4b6 tint: Misc hash / container contract improvements by Ben Clayton <bclayton@google.com>
- 6a17e33f3dafa0e9d6c9109c53223fc7b89c83af tint/writer: Remove spirv::Operand hasher by Ben Clayton <bclayton@google.com>
- b04d992f8395c812aacc727f0fa1150bddbe5c49 tint/utils: Fix Hashmap::GetOrCreate() for map mutation i... by Ben Clayton <bclayton@google.com>
- b6d524380ef9c3311ef999e59884bb040273360f tint: Improve resolver test helper to specify more than o... by Antonio Maiorano <amaiorano@google.com>
- cd716e6f01ae658501ad5471bac67926386398eb tint::CloneContext: Use Hashmap::Generation() by Ben Clayton <bclayton@google.com>
- 4e0335c5afec9d64203a9e43d83fd1ebc0fe29f7 tint/utils: Add Generation() to Hashmap and Hashset. by Ben Clayton <bclayton@google.com>
- ff0295ebd88c3bb0343f01404d991a06dd8fa7e8 tint: Fix AInt -> AFloat implicit conversion from constru... by Antonio Maiorano <amaiorano@google.com>
- 5361d9e778fb839c0e63adcc86edae267c5b3b44 Convert @id to an expression. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 426b47e4816542659fdce29b3c689045d288e199
Change-Id: I805d4ad728f7bf0a7baf11707496c6d07cec14b0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/100960
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index d140e1e..61bdea6 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -547,6 +547,8 @@
"transform/utils/hoist_to_decl_before.h",
"transform/var_for_dynamic_index.cc",
"transform/var_for_dynamic_index.h",
+ "transform/vectorize_matrix_conversions.cc",
+ "transform/vectorize_matrix_conversions.h",
"transform/vectorize_scalar_matrix_constructors.cc",
"transform/vectorize_scalar_matrix_constructors.h",
"transform/vertex_pulling.cc",
@@ -1043,7 +1045,6 @@
"ast/module_clone_test.cc",
"ast/module_test.cc",
"ast/multisampled_texture_test.cc",
- "ast/override_test.cc",
"ast/phony_expression_test.cc",
"ast/pointer_test.cc",
"ast/return_statement_test.cc",
@@ -1220,6 +1221,7 @@
"transform/utils/get_insertion_point_test.cc",
"transform/utils/hoist_to_decl_before_test.cc",
"transform/var_for_dynamic_index_test.cc",
+ "transform/vectorize_matrix_conversions_test.cc",
"transform/vectorize_scalar_matrix_constructors_test.cc",
"transform/vertex_pulling_test.cc",
"transform/while_to_loop_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 4ffaf94..62809ab 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -459,6 +459,8 @@
transform/utils/hoist_to_decl_before.h
transform/var_for_dynamic_index.cc
transform/var_for_dynamic_index.h
+ transform/vectorize_matrix_conversions.cc
+ transform/vectorize_matrix_conversions.h
transform/vectorize_scalar_matrix_constructors.cc
transform/vectorize_scalar_matrix_constructors.h
transform/vertex_pulling.cc
@@ -736,7 +738,6 @@
ast/module_clone_test.cc
ast/module_test.cc
ast/multisampled_texture_test.cc
- ast/override_test.cc
ast/phony_expression_test.cc
ast/pointer_test.cc
ast/return_statement_test.cc
@@ -1132,6 +1133,7 @@
transform/unshadow_test.cc
transform/unwind_discard_functions_test.cc
transform/var_for_dynamic_index_test.cc
+ transform/vectorize_matrix_conversions_test.cc
transform/vectorize_scalar_matrix_constructors_test.cc
transform/vertex_pulling_test.cc
transform/while_to_loop_test.cc
diff --git a/src/tint/ast/id_attribute.cc b/src/tint/ast/id_attribute.cc
index 75d62c6..9c1d1ae 100644
--- a/src/tint/ast/id_attribute.cc
+++ b/src/tint/ast/id_attribute.cc
@@ -22,7 +22,7 @@
namespace tint::ast {
-IdAttribute::IdAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t val)
+IdAttribute::IdAttribute(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* val)
: Base(pid, nid, src), value(val) {}
IdAttribute::~IdAttribute() = default;
@@ -34,7 +34,8 @@
const IdAttribute* IdAttribute::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
- return ctx->dst->create<IdAttribute>(src, value);
+ auto* value_ = ctx->Clone(value);
+ return ctx->dst->create<IdAttribute>(src, value_);
}
} // namespace tint::ast
diff --git a/src/tint/ast/id_attribute.h b/src/tint/ast/id_attribute.h
index ca2a358..f707bde 100644
--- a/src/tint/ast/id_attribute.h
+++ b/src/tint/ast/id_attribute.h
@@ -18,6 +18,7 @@
#include <string>
#include "src/tint/ast/attribute.h"
+#include "src/tint/ast/expression.h"
namespace tint::ast {
@@ -28,8 +29,8 @@
/// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier
/// @param src the source of this node
- /// @param val the numeric id value
- IdAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t val);
+ /// @param val the numeric id value expression
+ IdAttribute(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* val);
~IdAttribute() override;
/// @returns the WGSL name for the attribute
@@ -41,8 +42,8 @@
/// @return the newly cloned node
const IdAttribute* Clone(CloneContext* ctx) const override;
- /// The id value
- const uint32_t value;
+ /// The id value expression
+ const ast::Expression* const value;
};
} // namespace tint::ast
diff --git a/src/tint/ast/id_attribute_test.cc b/src/tint/ast/id_attribute_test.cc
index ad05c58..84605b1 100644
--- a/src/tint/ast/id_attribute_test.cc
+++ b/src/tint/ast/id_attribute_test.cc
@@ -19,11 +19,12 @@
namespace tint::ast {
namespace {
+using namespace tint::number_suffixes; // NOLINT
using IdAttributeTest = TestHelper;
TEST_F(IdAttributeTest, Creation) {
- auto* d = create<IdAttribute>(12u);
- EXPECT_EQ(12u, d->value);
+ auto* d = Id(12_a);
+ EXPECT_TRUE(d->value->Is<ast::IntLiteralExpression>());
}
} // namespace
diff --git a/src/tint/ast/override.cc b/src/tint/ast/override.cc
index 3a0a3a9..bb0f063 100644
--- a/src/tint/ast/override.cc
+++ b/src/tint/ast/override.cc
@@ -48,11 +48,4 @@
return ctx->dst->create<Override>(src, sym, ty, ctor, std::move(attrs));
}
-std::string Override::Identifier(const SymbolTable& symbols) const {
- if (auto* id = ast::GetAttribute<ast::IdAttribute>(attributes)) {
- return std::to_string(id->value);
- }
- return symbols.NameFor(symbol);
-}
-
} // namespace tint::ast
diff --git a/src/tint/ast/override.h b/src/tint/ast/override.h
index 7d01d13..c56cedf 100644
--- a/src/tint/ast/override.h
+++ b/src/tint/ast/override.h
@@ -62,12 +62,6 @@
/// @param ctx the clone context
/// @return the newly cloned node
const Override* Clone(CloneContext* ctx) const override;
-
- /// @param symbols the symbol table to retrieve the name from
- /// @returns the identifier string for the override. If the override has
- /// an ID attribute, the string is the id-stringified. Otherwise, the ID
- /// is the symbol.
- std::string Identifier(const SymbolTable& symbols) const;
};
} // namespace tint::ast
diff --git a/src/tint/ast/override_test.cc b/src/tint/ast/override_test.cc
deleted file mode 100644
index f037601..0000000
--- a/src/tint/ast/override_test.cc
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2022 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/ast/override.h"
-
-#include "src/tint/ast/test_helper.h"
-
-namespace tint::ast {
-namespace {
-
-using OverrideTest = TestHelper;
-
-TEST_F(OverrideTest, Identifier_NoId) {
- auto* o = Override("o", Expr(f32(1.0)));
- EXPECT_EQ(std::string("o"), o->Identifier(Symbols()));
-}
-
-TEST_F(OverrideTest, Identifier_WithId) {
- auto* o = Override("o", Expr(f32(1.0)), Id(4u));
- EXPECT_EQ(std::string("4"), o->Identifier(Symbols()));
-}
-
-} // namespace
-} // namespace tint::ast
diff --git a/src/tint/ast/variable_test.cc b/src/tint/ast/variable_test.cc
index e67336d..40dd68d 100644
--- a/src/tint/ast/variable_test.cc
+++ b/src/tint/ast/variable_test.cc
@@ -93,7 +93,7 @@
TEST_F(VariableTest, WithAttributes) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Location(1u),
- Builtin(BuiltinValue::kPosition), Id(1200u));
+ Builtin(BuiltinValue::kPosition), Id(1200_u));
auto& attributes = var->attributes;
EXPECT_TRUE(ast::HasAttribute<ast::LocationAttribute>(attributes));
diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h
index 3c5f6ec..8e045f5 100644
--- a/src/tint/clone_context.h
+++ b/src/tint/clone_context.h
@@ -541,6 +541,8 @@
/// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
struct VectorListTransforms {
+ using Map = utils::Hashmap<const void*, ListTransforms, 4>;
+
/// An accessor to the VectorListTransforms map.
/// Index caches the last map lookup, and will only re-search the map if the transform map
/// was modified since the last lookup.
@@ -560,44 +562,36 @@
private:
friend VectorListTransforms;
- Index(const void* list,
- VectorListTransforms& vlt,
- uint32_t generation,
- const ListTransforms* cached)
- : list_(list), vlt_(vlt), generation_(generation), cached_(cached) {}
+ Index(const void* list, Map* map)
+ : list_(list),
+ map_(map),
+ generation_(map->Generation()),
+ cached_(map_->Find(list)) {}
void Update() {
- if (vlt_.generation_ != generation_) {
- cached_ = vlt_.map_.Find(list_);
- generation_ = vlt_.generation_;
+ if (map_->Generation() != generation_) {
+ cached_ = map_->Find(list_);
+ generation_ = map_->Generation();
}
}
const void* list_;
- VectorListTransforms& vlt_;
- uint32_t generation_;
+ Map* map_;
+ uint64_t generation_;
const ListTransforms* cached_;
};
/// Edit returns a reference to the ListTransforms for the given vector pointer and
/// increments #list_transform_generation_ signalling that the list transforms have been
/// modified.
- inline ListTransforms& Edit(const void* list) {
- generation_++;
- return map_.GetOrZero(list);
- }
+ inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); }
/// @returns an Index to the transforms for the given list.
- inline Index Find(const void* list) {
- return Index{list, *this, generation_, map_.Find(list)};
- }
+ inline Index Find(const void* list) { return Index{list, &map_}; }
private:
/// The map of vector pointer to ListTransforms
- utils::Hashmap<const void*, ListTransforms, 4> map_;
-
- /// A counter that's incremented each time list transforms are modified.
- uint32_t generation_ = 0;
+ Map map_;
};
/// A map of object in #src to functions that create their replacement in #dst
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 5f0013f..5bfd08e 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -697,8 +697,8 @@
}
TEST_F(InspectorGetEntryPointTest, OverrideSomeReferenced) {
- Override("foo", ty.f32(), Id(1));
- Override("bar", ty.f32(), Id(2));
+ Override("foo", ty.f32(), Id(1_a));
+ Override("bar", ty.f32(), Id(2_a));
MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), utils::Empty);
MakeCallerBodyFunction("ep_func", utils::Vector{std::string("callee_func")},
utils::Vector{
@@ -789,7 +789,7 @@
TEST_F(InspectorGetEntryPointTest, OverrideNumericIDSpecified) {
Override("foo_no_id", ty.f32());
- Override("foo_id", ty.f32(), Id(1234));
+ Override("foo_id", ty.f32(), Id(1234_a));
MakePlainGlobalReferenceBodyFunction("no_id_func", "foo_no_id", ty.f32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("id_func", "foo_id", ty.f32(), utils::Empty);
@@ -1225,9 +1225,9 @@
InterpolationType::kFlat, InterpolationSampling::kNone}));
TEST_F(InspectorGetOverrideDefaultValuesTest, Bool) {
- Override("foo", ty.bool_(), Id(1));
- Override("bar", ty.bool_(), Expr(true), Id(20));
- Override("baz", ty.bool_(), Expr(false), Id(300));
+ Override("foo", ty.bool_(), Id(1_a));
+ Override("bar", ty.bool_(), Expr(true), Id(20_a));
+ Override("baz", ty.bool_(), Expr(false), Id(300_a));
Inspector& inspector = Build();
@@ -1247,8 +1247,8 @@
}
TEST_F(InspectorGetOverrideDefaultValuesTest, U32) {
- Override("foo", ty.u32(), Id(1));
- Override("bar", ty.u32(), Expr(42_u), Id(20));
+ Override("foo", ty.u32(), Id(1_a));
+ Override("bar", ty.u32(), Expr(42_u), Id(20_a));
Inspector& inspector = Build();
@@ -1264,9 +1264,9 @@
}
TEST_F(InspectorGetOverrideDefaultValuesTest, I32) {
- Override("foo", ty.i32(), Id(1));
- Override("bar", ty.i32(), Expr(-42_i), Id(20));
- Override("baz", ty.i32(), Expr(42_i), Id(300));
+ Override("foo", ty.i32(), Id(1_a));
+ Override("bar", ty.i32(), Expr(-42_i), Id(20_a));
+ Override("baz", ty.i32(), Expr(42_i), Id(300_a));
Inspector& inspector = Build();
@@ -1286,10 +1286,10 @@
}
TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
- Override("foo", ty.f32(), Id(1));
- Override("bar", ty.f32(), Expr(0_f), Id(20));
- Override("baz", ty.f32(), Expr(-10_f), Id(300));
- Override("x", ty.f32(), Expr(15_f), Id(4000));
+ Override("foo", ty.f32(), Id(1_a));
+ Override("bar", ty.f32(), Expr(0_f), Id(20_a));
+ Override("baz", ty.f32(), Expr(-10_f), Id(300_a));
+ Override("x", ty.f32(), Expr(15_f), Id(4000_a));
Inspector& inspector = Build();
@@ -1313,9 +1313,9 @@
}
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
- Override("v1", ty.f32(), Id(1));
- Override("v20", ty.f32(), Id(20));
- Override("v300", ty.f32(), Id(300));
+ Override("v1", ty.f32(), Id(1_a));
+ Override("v20", ty.f32(), Id(20_a));
+ Override("v300", ty.f32(), Id(300_a));
auto* a = Override("a", ty.f32());
auto* b = Override("b", ty.f32());
auto* c = Override("c", ty.f32());
diff --git a/src/tint/number.h b/src/tint/number.h
index a36ccfb..4635051 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -496,4 +496,19 @@
} // namespace tint::number_suffixes
+namespace std {
+
+/// Custom std::hash specialization for tint::Number<T>
+template <typename T>
+class hash<tint::Number<T>> {
+ public:
+ /// @param n the Number
+ /// @return the hash value
+ inline std::size_t operator()(const tint::Number<T>& n) const {
+ return std::hash<decltype(n.value)>()(n.value);
+ }
+};
+
+} // namespace std
+
#endif // SRC_TINT_NUMBER_H_
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 033577b..487420c 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2946,26 +2946,32 @@
/// @param id the id value
/// @returns the override attribute pointer
const ast::IdAttribute* Id(const Source& source, OverrideId id) {
- return create<ast::IdAttribute>(source, id.value);
+ return create<ast::IdAttribute>(source, Expr(AInt(id.value)));
}
/// Creates an ast::IdAttribute with an override identifier
/// @param id the optional id value
/// @returns the override attribute pointer
- const ast::IdAttribute* Id(OverrideId id) { return Id(source_, id); }
+ const ast::IdAttribute* Id(OverrideId id) {
+ return create<ast::IdAttribute>(Expr(AInt(id.value)));
+ }
/// Creates an ast::IdAttribute
/// @param source the source information
- /// @param id the id value
+ /// @param id the id value expression
/// @returns the override attribute pointer
- const ast::IdAttribute* Id(const Source& source, uint32_t id) {
- return create<ast::IdAttribute>(source, id);
+ template <typename EXPR>
+ const ast::IdAttribute* Id(const Source& source, EXPR&& id) {
+ return create<ast::IdAttribute>(source, Expr(std::forward<EXPR>(id)));
}
/// Creates an ast::IdAttribute with an override identifier
- /// @param id the optional id value
+ /// @param id the optional id value expression
/// @returns the override attribute pointer
- const ast::IdAttribute* Id(uint32_t id) { return Id(source_, id); }
+ template <typename EXPR>
+ const ast::IdAttribute* Id(EXPR&& id) {
+ return create<ast::IdAttribute>(Expr(std::forward<EXPR>(id)));
+ }
/// Creates an ast::StageAttribute
/// @param source the source information
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index 85caaf9..2942591 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -1366,7 +1366,7 @@
"between 0 and 65535: ID %"
<< inst.result_id() << " has SpecId " << id;
}
- auto* cid = create<ast::IdAttribute>(Source{}, id);
+ auto* cid = builder_.Id(Source{}, AInt(id));
spec_id_decos.Push(cid);
break;
}
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index aee4b99..d57ad78 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3508,7 +3508,9 @@
}
match(Token::Type::kComma);
- return create<ast::IdAttribute>(t.source(), val.value);
+ return create<ast::IdAttribute>(
+ t.source(), create<ast::IntLiteralExpression>(
+ val.value, ast::IntLiteralExpression::Suffix::kNone));
});
}
diff --git a/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc
index 46bfa34..2f1397f 100644
--- a/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc
@@ -201,7 +201,7 @@
auto* override_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_NE(override_attr, nullptr);
- EXPECT_EQ(override_attr->value, 7u);
+ EXPECT_TRUE(override_attr->value->Is<ast::IntLiteralExpression>());
}
TEST_F(ParserImplTest, GlobalOverrideDecl_WithId_TrailingComma) {
@@ -231,7 +231,7 @@
auto* override_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_NE(override_attr, nullptr);
- EXPECT_EQ(override_attr->value, 7u);
+ EXPECT_TRUE(override_attr->value->Is<ast::IntLiteralExpression>());
}
TEST_F(ParserImplTest, GlobalOverrideDecl_WithoutId) {
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index fd06200..0180589 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -97,7 +97,7 @@
case AttributeKind::kGroup:
return {builder.Group(source, 1_a)};
case AttributeKind::kId:
- return {builder.create<ast::IdAttribute>(source, 0u)};
+ return {builder.Id(source, 0_a)};
case AttributeKind::kInterpolate:
return {builder.Interpolate(source, ast::InterpolationType::kLinear,
ast::InterpolationSampling::kCenter)};
@@ -786,8 +786,8 @@
TEST_F(ConstantAttributeTest, DuplicateAttribute) {
GlobalConst("a", ty.f32(), Expr(1.23_f),
utils::Vector{
- create<ast::IdAttribute>(Source{{12, 34}}, 0u),
- create<ast::IdAttribute>(Source{{56, 78}}, 1u),
+ Id(Source{{12, 34}}, 0_a),
+ Id(Source{{56, 78}}, 1_a),
});
EXPECT_FALSE(r()->Resolve());
@@ -829,8 +829,8 @@
TEST_F(OverrideAttributeTest, DuplicateAttribute) {
Override("a", ty.f32(), Expr(1.23_f),
utils::Vector{
- create<ast::IdAttribute>(Source{{12, 34}}, 0u),
- create<ast::IdAttribute>(Source{{56, 78}}, 1u),
+ Id(Source{{12, 34}}, 0_a),
+ Id(Source{{56, 78}}, 1_a),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/bitcast_validation_test.cc b/src/tint/resolver/bitcast_validation_test.cc
index d341808..c6d4cb4 100644
--- a/src/tint/resolver/bitcast_validation_test.cc
+++ b/src/tint/resolver/bitcast_validation_test.cc
@@ -25,12 +25,12 @@
template <typename T>
static constexpr Type Create() {
return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
- builder::DataType<T>::Expr};
+ builder::DataType<T>::ExprFromDouble};
}
builder::ast_type_func_ptr ast;
builder::sem_type_func_ptr sem;
- builder::ast_expr_func_ptr expr;
+ builder::ast_expr_from_double_func_ptr expr;
};
static constexpr Type kNumericScalars[] = {
diff --git a/src/tint/resolver/call_test.cc b/src/tint/resolver/call_test.cc
index 37aaffa..e5b245c 100644
--- a/src/tint/resolver/call_test.cc
+++ b/src/tint/resolver/call_test.cc
@@ -57,13 +57,13 @@
using ResolverCallTest = ResolverTest;
struct Params {
- builder::ast_expr_func_ptr create_value;
+ builder::ast_expr_from_double_func_ptr create_value;
builder::ast_type_func_ptr create_type;
};
template <typename T>
constexpr Params ParamsFor() {
- return Params{DataType<T>::Expr, DataType<T>::AST};
+ return Params{DataType<T>::ExprFromDouble, DataType<T>::AST};
}
static constexpr Params all_param_types[] = {
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index d515ac4..23e621f 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -2982,6 +2982,32 @@
EXPECT_EQ(i2->ConstantValue()->As<u32>(), 2_u);
}
+TEST_F(ResolverConstEvalTest, Matrix_AFloat_Construct_From_AInt_Vectors) {
+ auto* c = Const("a", Construct(ty.mat(nullptr, 2, 2), //
+ Construct(ty.vec(nullptr, 2), Expr(1_a), Expr(2_a)),
+ Construct(ty.vec(nullptr, 2), Expr(3_a), Expr(4_a))));
+ WrapInFunction(c);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(c);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<sem::Matrix>());
+ auto* cv = sem->ConstantValue();
+ EXPECT_TYPE(cv->Type(), sem->Type());
+ EXPECT_TRUE(cv->Index(0)->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(cv->Index(0)->Index(0)->Type()->Is<sem::AbstractFloat>());
+ EXPECT_FALSE(cv->AllEqual());
+ EXPECT_FALSE(cv->AnyZero());
+ EXPECT_FALSE(cv->AllZero());
+ auto* c0 = cv->Index(0);
+ auto* c1 = cv->Index(1);
+ EXPECT_EQ(std::get<AFloat>(c0->Index(0)->Value()), 1.0);
+ EXPECT_EQ(std::get<AFloat>(c0->Index(1)->Value()), 2.0);
+ EXPECT_EQ(std::get<AFloat>(c1->Index(0)->Value()), 3.0);
+ EXPECT_EQ(std::get<AFloat>(c1->Index(1)->Value()), 4.0);
+}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// Unary op
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -3129,29 +3155,127 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary op
////////////////////////////////////////////////////////////////////////////////////////////////////
+
namespace binary_op {
-using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
+using builder::IsValue;
+using builder::Mat;
+using builder::Val;
+using builder::Value;
+using builder::Vec;
+
+using Types = std::variant<Value<AInt>,
+ Value<AFloat>,
+ Value<u32>,
+ Value<i32>,
+ Value<f32>,
+ Value<f16>,
+
+ Value<builder::vec2<AInt>>,
+ Value<builder::vec2<AFloat>>,
+ Value<builder::vec2<u32>>,
+ Value<builder::vec2<i32>>,
+ Value<builder::vec2<f32>>,
+ Value<builder::vec2<f16>>,
+
+ Value<builder::vec3<AInt>>,
+ Value<builder::vec3<AFloat>>,
+ Value<builder::vec3<u32>>,
+ Value<builder::vec3<i32>>,
+ Value<builder::vec3<f32>>,
+ Value<builder::vec3<f16>>,
+
+ Value<builder::vec4<AInt>>,
+ Value<builder::vec4<AFloat>>,
+ Value<builder::vec4<u32>>,
+ Value<builder::vec4<i32>>,
+ Value<builder::vec4<f32>>,
+ Value<builder::vec4<f16>>,
+
+ Value<builder::mat2x2<AInt>>,
+ Value<builder::mat2x2<AFloat>>,
+ Value<builder::mat2x2<f32>>,
+ Value<builder::mat2x2<f16>>,
+
+ Value<builder::mat2x3<AInt>>,
+ Value<builder::mat2x3<AFloat>>,
+ Value<builder::mat2x3<f32>>,
+ Value<builder::mat2x3<f16>>,
+
+ Value<builder::mat3x2<AInt>>,
+ Value<builder::mat3x2<AFloat>>,
+ Value<builder::mat3x2<f32>>,
+ Value<builder::mat3x2<f16>>
+ //
+ >;
struct Case {
Types lhs;
Types rhs;
Types expected;
- bool is_overflow;
+ bool overflow;
};
+/// Creates a Case with Values of any type
+template <typename T, typename U, typename V>
+Case C(Value<T> lhs, Value<U> rhs, Value<V> expected, bool overflow = false) {
+ return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow};
+}
+
+/// Convenience overload to creates a Case with just scalars
+template <typename T, typename U, typename V, typename = std::enable_if_t<!IsValue<T>>>
+Case C(T lhs, U rhs, V expected, bool overflow = false) {
+ return Case{Val(lhs), Val(rhs), Val(expected), overflow};
+}
+
static std::ostream& operator<<(std::ostream& o, const Case& c) {
- std::visit(
- [&](auto&& lhs, auto&& rhs, auto&& expected) {
- o << "lhs: " << lhs << ", rhs: " << rhs << ", expected: " << expected;
- },
- c.lhs, c.rhs, c.expected);
+ auto print_value = [&](auto&& value) {
+ std::visit(
+ [&](auto&& v) {
+ using ValueType = std::decay_t<decltype(v)>;
+ o << ValueType::DataType::Name() << "(";
+ for (auto& a : v.args.values) {
+ o << std::get<typename ValueType::ElementType>(a);
+ if (&a != &v.args.values.Back()) {
+ o << ", ";
+ }
+ }
+ o << ")";
+ },
+ value);
+ };
+ o << "lhs: ";
+ print_value(c.lhs);
+ o << ", rhs: ";
+ print_value(c.rhs);
+ o << ", expected: ";
+ print_value(c.expected);
+ o << ", overflow: " << c.overflow;
return o;
}
-template <typename T, typename U, typename V>
-Case C(T lhs, U rhs, V expected, bool is_overflow = false) {
- return Case{lhs, rhs, expected, is_overflow};
+// Calls `f` on deepest elements of both `a` and `b`. If function returns false, it stops
+// traversing, and return false, otherwise it continues and returns true.
+// TODO(amaiorano): Move to Constant.h?
+template <typename Func>
+bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
+ EXPECT_EQ(a->Type(), b->Type());
+ size_t i = 0;
+ while (true) {
+ auto* a_elem = a->Index(i);
+ if (!a_elem) {
+ break;
+ }
+ auto* b_elem = b->Index(i);
+ if (!ForEachElemPair(a_elem, b_elem, f)) {
+ return false;
+ }
+ i++;
+ }
+ if (i == 0) {
+ return f(a, b);
+ }
+ return true;
}
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
@@ -3159,35 +3283,51 @@
Enable(ast::Extension::kF16);
auto op = std::get<0>(GetParam());
- auto c = std::get<1>(GetParam());
- std::visit(
- [&](auto&& lhs, auto&& rhs, auto&& expected) {
- using T = std::decay_t<decltype(expected)>;
+ auto& c = std::get<1>(GetParam());
+ std::visit(
+ [&](auto&& expected) {
+ using T = typename std::decay_t<decltype(expected)>::ElementType;
if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
- if (c.is_overflow) {
+ if (c.overflow) {
+ // Overflow is not allowed for abstract types. This is tested separately.
return;
}
}
- auto* expr = create<ast::BinaryExpression>(op, Expr(lhs), Expr(rhs));
+ auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
+ auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
+ auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
+
GlobalConst("C", expr);
+ auto* expected_expr = expected.Expr(*this);
+ GlobalConst("E", expected_expr);
+
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
const sem::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
- EXPECT_EQ(value->As<T>(), expected);
- if constexpr (IsInteger<UnwrapNumber<T>>) {
- // Check that the constant's integer doesn't contain unexpected data in the MSBs
- // that are outside of the bit-width of T.
- EXPECT_EQ(value->As<AInt>(), AInt(expected));
- }
+ auto* expected_sem = Sem().Get(expected_expr);
+ const sem::Constant* expected_value = expected_sem->ConstantValue();
+ ASSERT_NE(expected_value, nullptr);
+ EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
+
+ ForEachElemPair(value, expected_value,
+ [&](const sem::Constant* a, const sem::Constant* b) {
+ EXPECT_EQ(a->As<T>(), b->As<T>());
+ if constexpr (IsInteger<UnwrapNumber<T>>) {
+ // Check that the constant's integer doesn't contain unexpected
+ // data in the MSBs that are outside of the bit-width of T.
+ EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
+ }
+ return !HasFailure();
+ });
},
- c.lhs, c.rhs, c.expected);
+ c.expected);
}
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
@@ -3299,33 +3439,37 @@
EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'");
}
-TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarScalar) {
- auto* a = Const("a", Expr(1_a)); // AInt
- auto* b = Const("b", Expr(2.3_a)); // AFloat
- auto* c = Add(Expr("a"), Expr("b"));
- WrapInFunction(a, b, c);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
- auto* sem = Sem().Get(c);
- ASSERT_TRUE(sem);
- ASSERT_TRUE(sem->ConstantValue());
- auto result = sem->ConstantValue()->As<AFloat>();
- EXPECT_EQ(result, 3.3f);
-}
-
-TEST_F(ResolverConstEvalTest, BinaryAbstractMixed_ScalarVector) {
- auto* a = Const("a", Expr(1_a)); // AInt
- auto* b = Const("b", Construct(ty.vec(nullptr, 3), Expr(2.3_a))); // AFloat
- auto* c = Add(Expr("a"), Expr("b"));
- WrapInFunction(a, b, c);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
- auto* sem = Sem().Get(c);
- ASSERT_TRUE(sem);
- ASSERT_TRUE(sem->ConstantValue());
- EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 3.3f);
- EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 3.3f);
- EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 3.3f);
-}
-
+// Mixed AInt and AFloat args to test implicit conversion to AFloat
+INSTANTIATE_TEST_SUITE_P(
+ AbstractMixed,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine(
+ testing::Values(ast::BinaryOp::kAdd),
+ testing::Values(C(Val(1_a), Val(2.3_a), Val(3.3_a)),
+ C(Val(2.3_a), Val(1_a), Val(3.3_a)),
+ C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)),
+ C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)),
+ C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)),
+ C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)),
+ C(Mat({1_a, 2_a}, //
+ {1_a, 2_a}, //
+ {1_a, 2_a}), //
+ Mat({1.2_a, 2.3_a}, //
+ {1.2_a, 2.3_a}, //
+ {1.2_a, 2.3_a}), //
+ Mat({2.2_a, 4.3_a}, //
+ {2.2_a, 4.3_a}, //
+ {2.2_a, 4.3_a})), //
+ C(Mat({1.2_a, 2.3_a}, //
+ {1.2_a, 2.3_a}, //
+ {1.2_a, 2.3_a}), //
+ Mat({1_a, 2_a}, //
+ {1_a, 2_a}, //
+ {1_a, 2_a}), //
+ Mat({2.2_a, 4.3_a}, //
+ {2.2_a, 4.3_a}, //
+ {2.2_a, 4.3_a})) //
+ )));
} // namespace binary_op
////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
index 8ce611e..83ca0bd 100644
--- a/src/tint/resolver/inferred_type_test.cc
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -43,13 +43,15 @@
struct ResolverInferredTypeTest : public resolver::TestHelper, public testing::Test {};
struct Params {
- builder::ast_expr_func_ptr create_value;
+ // builder::ast_expr_func_ptr_default_arg create_value;
+ builder::ast_expr_from_double_func_ptr create_value;
builder::sem_type_func_ptr create_expected_type;
};
template <typename T>
constexpr Params ParamsFor() {
- return Params{DataType<T>::Expr, DataType<T>::Sem};
+ // return Params{builder::CreateExprWithDefaultArg<T>(), DataType<T>::Sem};
+ return Params{DataType<T>::ExprFromDouble, DataType<T>::Sem};
}
Params all_cases[] = {
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 01ff0d5..239efd5 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -255,9 +255,9 @@
std::string target_element_type_name;
builder::ast_type_func_ptr target_ast_ty;
builder::sem_type_func_ptr target_sem_ty;
- builder::ast_expr_func_ptr target_expr;
+ builder::ast_expr_from_double_func_ptr target_expr;
std::string abstract_type_name;
- builder::ast_expr_func_ptr abstract_expr;
+ builder::ast_expr_from_double_func_ptr abstract_expr;
std::variant<AInt, AFloat> materialized_value;
double literal_value;
};
@@ -268,13 +268,13 @@
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
- TargetDataType::Name(), // target_type_name
- TargetElementDataType::Name(), // target_element_type_name
- TargetDataType::AST, // target_ast_ty
- TargetDataType::Sem, // target_sem_ty
- TargetDataType::Expr, // target_expr
- AbstractDataType::Name(), // abstract_type_name
- AbstractDataType::Expr, // abstract_expr
+ TargetDataType::Name(), // target_type_name
+ TargetElementDataType::Name(), // target_element_type_name
+ TargetDataType::AST, // target_ast_ty
+ TargetDataType::Sem, // target_sem_ty
+ TargetDataType::ExprFromDouble, // target_expr
+ AbstractDataType::Name(), // abstract_type_name
+ AbstractDataType::ExprFromDouble, // abstract_expr
materialized_value,
literal_value,
};
@@ -286,13 +286,13 @@
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
- TargetDataType::Name(), // target_type_name
- TargetElementDataType::Name(), // target_element_type_name
- TargetDataType::AST, // target_ast_ty
- TargetDataType::Sem, // target_sem_ty
- TargetDataType::Expr, // target_expr
- AbstractDataType::Name(), // abstract_type_name
- AbstractDataType::Expr, // abstract_expr
+ TargetDataType::Name(), // target_type_name
+ TargetElementDataType::Name(), // target_element_type_name
+ TargetDataType::AST, // target_ast_ty
+ TargetDataType::Sem, // target_sem_ty
+ TargetDataType::ExprFromDouble, // target_expr
+ AbstractDataType::Name(), // abstract_type_name
+ AbstractDataType::ExprFromDouble, // abstract_expr
0_a,
0.0,
};
@@ -826,7 +826,7 @@
std::string expected_element_type_name;
builder::sem_type_func_ptr expected_sem_ty;
std::string abstract_type_name;
- builder::ast_expr_func_ptr abstract_expr;
+ builder::ast_expr_from_double_func_ptr abstract_expr;
std::variant<AInt, AFloat> materialized_value;
double literal_value;
};
@@ -837,11 +837,11 @@
using AbstractDataType = builder::DataType<ABSTRACT_TYPE>;
using TargetElementDataType = builder::DataType<typename ExpectedDataType::ElementType>;
return {
- ExpectedDataType::Name(), // expected_type_name
- TargetElementDataType::Name(), // expected_element_type_name
- ExpectedDataType::Sem, // expected_sem_ty
- AbstractDataType::Name(), // abstract_type_name
- AbstractDataType::Expr, // abstract_expr
+ ExpectedDataType::Name(), // expected_type_name
+ TargetElementDataType::Name(), // expected_element_type_name
+ ExpectedDataType::Sem, // expected_sem_ty
+ AbstractDataType::Name(), // abstract_type_name
+ AbstractDataType::ExprFromDouble, // abstract_expr
materialized_value,
literal_value,
};
diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc
index 08fec58..4fdbcee 100644
--- a/src/tint/resolver/override_test.cc
+++ b/src/tint/resolver/override_test.cc
@@ -50,7 +50,7 @@
}
TEST_F(ResolverOverrideTest, WithId) {
- auto* a = Override("a", ty.f32(), Expr(1_f), Id(7u));
+ auto* a = Override("a", ty.f32(), Expr(1_f), Id(7_u));
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -69,10 +69,10 @@
std::vector<ast::Variable*> variables;
auto* a = Override("a", ty.f32(), Expr(1_f));
auto* b = Override("b", ty.f32(), Expr(1_f));
- auto* c = Override("c", ty.f32(), Expr(1_f), Id(2u));
- auto* d = Override("d", ty.f32(), Expr(1_f), Id(4u));
+ auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u));
+ auto* d = Override("d", ty.f32(), Expr(1_f), Id(4_u));
auto* e = Override("e", ty.f32(), Expr(1_f));
- auto* f = Override("f", ty.f32(), Expr(1_f), Id(1u));
+ auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u));
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -86,8 +86,8 @@
}
TEST_F(ResolverOverrideTest, DuplicateIds) {
- Override("a", ty.f32(), Expr(1_f), Id(Source{{12, 34}}, 7u));
- Override("b", ty.f32(), Expr(1_f), Id(Source{{56, 78}}, 7u));
+ Override("a", ty.f32(), Expr(1_f), Id(Source{{12, 34}}, 7_u));
+ Override("b", ty.f32(), Expr(1_f), Id(Source{{56, 78}}, 7_u));
EXPECT_FALSE(r()->Resolve());
@@ -96,7 +96,7 @@
}
TEST_F(ResolverOverrideTest, IdTooLarge) {
- Override("a", ty.f32(), Expr(1_f), Id(Source{{12, 34}}, 65536u));
+ Override("a", ty.f32(), Expr(1_f), Id(Source{{12, 34}}, 65536_u));
EXPECT_FALSE(r()->Resolve());
@@ -106,7 +106,7 @@
TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
Enable(ast::Extension::kF16);
- Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1u));
+ Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1_u));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 527c921..1599334 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -23,6 +23,7 @@
#include "src/tint/ast/alias.h"
#include "src/tint/ast/array.h"
#include "src/tint/ast/assignment_statement.h"
+#include "src/tint/ast/attribute.h"
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/call_statement.h"
@@ -437,12 +438,35 @@
auto* sem = builder_->create<sem::GlobalVariable>(
v, ty, sem::EvaluationStage::kOverride, ast::StorageClass::kNone, ast::Access::kUndefined,
/* constant_value */ nullptr, sem::BindingPoint{});
+ sem->SetConstructor(rhs);
- if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
- sem->SetOverrideId(OverrideId{static_cast<decltype(OverrideId::value)>(id->value)});
+ if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
+ auto* materialize = Materialize(Expression(id_attr->value));
+ if (!materialize) {
+ return nullptr;
+ }
+ auto* c = materialize->ConstantValue();
+ if (!c) {
+ // TODO(crbug.com/tint/1633): Handle invalid materialization when expressions
+ // are supported.
+ return nullptr;
+ }
+
+ auto value = c->As<uint32_t>();
+ if (value > std::numeric_limits<decltype(OverrideId::value)>::max()) {
+ AddError("override IDs must be between 0 and " +
+ std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
+ id_attr->source);
+ return nullptr;
+ }
+
+ auto o = OverrideId{static_cast<decltype(OverrideId::value)>(value)};
+ sem->SetOverrideId(o);
+
+ // Track the constant IDs that are specified in the shader.
+ override_ids_.emplace(o, sem);
}
- sem->SetConstructor(rhs);
builder_->Sem().Add(v, sem);
return sem;
}
@@ -737,8 +761,8 @@
}
OverrideId id;
- if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) {
- id = OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)};
+ if (ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
+ id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
} else {
// No ID was specified, so allocate the next available ID.
while (!ids_exhausted && override_ids_.count(next_id)) {
@@ -777,12 +801,6 @@
for (auto* attr : v->attributes) {
Mark(attr);
-
- if (auto* id_attr = attr->As<ast::IdAttribute>()) {
- // Track the constant IDs that are specified in the shader.
- override_ids_.emplace(
- OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)}, sem);
- }
}
if (!validator_.NoDuplicateAttributes(v->attributes)) {
@@ -1672,10 +1690,12 @@
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
if (stage == sem::EvaluationStage::kConstant) {
- auto const_args =
- utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
+ auto const_args = ConvertArguments(args, ctor_or_conv.target);
+ if (!const_args) {
+ return nullptr;
+ }
if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)(
- ctor_or_conv.target->ReturnType(), const_args, expr->source)) {
+ ctor_or_conv.target->ReturnType(), const_args.Get(), expr->source)) {
value = r.Get();
} else {
return nullptr;
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 595973d..72d741e 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -1001,9 +1001,9 @@
// @id(2) override depth = 2i;
// @compute @workgroup_size(width, height, depth)
// fn main() {}
- auto* width = Override("width", ty.i32(), Expr(16_i), Id(0));
- auto* height = Override("height", ty.i32(), Expr(8_i), Id(1));
- auto* depth = Override("depth", ty.i32(), Expr(2_i), Id(2));
+ auto* width = Override("width", ty.i32(), Expr(16_i), Id(0_a));
+ auto* height = Override("height", ty.i32(), Expr(8_i), Id(1_a));
+ auto* depth = Override("depth", ty.i32(), Expr(2_i), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
@@ -1029,9 +1029,9 @@
// @id(2) override depth : i32;
// @compute @workgroup_size(width, height, depth)
// fn main() {}
- auto* width = Override("width", ty.i32(), Id(0));
- auto* height = Override("height", ty.i32(), Id(1));
- auto* depth = Override("depth", ty.i32(), Id(2));
+ auto* width = Override("width", ty.i32(), Id(0_a));
+ auto* height = Override("height", ty.i32(), Id(1_a));
+ auto* depth = Override("depth", ty.i32(), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
@@ -1056,7 +1056,7 @@
// const depth = 3i;
// @compute @workgroup_size(8, height, depth)
// fn main() {}
- auto* height = Override("height", ty.i32(), Expr(2_i), Id(0));
+ auto* height = Override("height", ty.i32(), Expr(2_i), Id(0_a));
GlobalConst("depth", ty.i32(), Expr(3_i));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index bab1d84..6641176 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -18,6 +18,9 @@
#include <functional>
#include <memory>
#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
#include "gtest/gtest.h"
#include "src/tint/program_builder.h"
@@ -170,8 +173,35 @@
template <typename TO>
struct ptr {};
+/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
+/// composite types, or all values requried by the composite type.
+struct ScalarArgs {
+ /// Constructor
+ /// @param single_value single value to initialize with
+ template <typename T>
+ ScalarArgs(T single_value) // NOLINT: implicit on purpose
+ : values(utils::Vector<Storage, 1>{single_value}) {}
+
+ /// Constructor
+ /// @param all_values all values to initialize the composite type with
+ template <typename T>
+ ScalarArgs(utils::VectorRef<T> all_values) // NOLINT: implicit on purpose
+ {
+ for (auto& v : all_values) {
+ values.Push(v);
+ }
+ }
+
+ /// Valid scalar types for args
+ using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
+
+ /// The vector of values
+ utils::Vector<Storage, 16> values;
+};
+
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
-using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double elem_value);
+using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
+using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
template <typename T>
@@ -202,10 +232,16 @@
/// @return the semantic bool type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::Bool>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the b
+ /// @param args args of size 1 with the boolean value to init with
/// @return a new AST expression of the bool type
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(std::equal_to<double>()(elem_value, 0));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<bool>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to bool.
+ /// @return a new AST expression of the bool type
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@@ -227,10 +263,16 @@
/// @return the semantic i32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::I32>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value i32 will be initialized with
+ /// @param args args of size 1 with the i32 value to init with
/// @return a new AST i32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(static_cast<i32>(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<i32>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to i32.
+ /// @return a new AST i32 literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; }
@@ -252,10 +294,16 @@
/// @return the semantic u32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::U32>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value u32 will be initialized with
+ /// @param args args of size 1 with the u32 value to init with
/// @return a new AST u32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(static_cast<u32>(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<u32>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to u32.
+ /// @return a new AST u32 literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; }
@@ -277,10 +325,16 @@
/// @return the semantic f32 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::F32>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value f32 will be initialized with
+ /// @param args args of size 1 with the f32 value to init with
/// @return a new AST f32 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(static_cast<f32>(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<f32>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to f32.
+ /// @return a new AST f32 literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<f32>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; }
@@ -302,10 +356,16 @@
/// @return the semantic f16 type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::F16>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value f16 will be initialized with
+ /// @param args args of size 1 with the f16 value to init with
/// @return a new AST f16 literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(static_cast<f16>(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<f16>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to f16.
+ /// @return a new AST f16 literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; }
@@ -326,10 +386,16 @@
/// @return the semantic abstract-float type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::AbstractFloat>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value the abstract-float literal will be constructed with
+ /// @param args args of size 1 with the abstract-float value to init with
/// @return a new AST abstract-float literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(AFloat(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<AFloat>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to AFloat.
+ /// @return a new AST abstract-float literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; }
@@ -350,10 +416,16 @@
/// @return the semantic abstract-int type
static inline const sem::Type* Sem(ProgramBuilder& b) { return b.create<sem::AbstractInt>(); }
/// @param b the ProgramBuilder
- /// @param elem_value the value the abstract-int literal will be constructed with
+ /// @param args args of size 1 with the abstract-int value to init with
/// @return a new AST abstract-int literal value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Expr(AInt(elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Expr(std::get<AInt>(args.values[0]));
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to AInt.
+ /// @return a new AST abstract-int literal value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; }
@@ -379,22 +451,27 @@
return b.create<sem::Vector>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
- /// @param elem_value the value each element in the vector will be initialized
- /// with
+ /// @param args args of size 1 or N with values of type T to initialize with
/// @return a new AST vector value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Construct(AST(b), ExprArgs(b, elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
-
/// @param b the ProgramBuilder
- /// @param elem_value the value each element will be initialized with
+ /// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the vector
- static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
- utils::Vector<const ast::Expression*, N> args;
- for (uint32_t i = 0; i < N; i++) {
- args.Push(DataType<T>::Expr(b, elem_value));
+ static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
+ const bool one_value = args.values.Length() == 1;
+ utils::Vector<const ast::Expression*, N> r;
+ for (size_t i = 0; i < N; ++i) {
+ r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
}
- return args;
+ return r;
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to ElementType
+ /// @return a new AST vector value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -423,22 +500,36 @@
return b.create<sem::Matrix>(column_type, N);
}
/// @param b the ProgramBuilder
- /// @param elem_value the value each element in the matrix will be initialized
- /// with
+ /// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Construct(AST(b), ExprArgs(b, elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
-
/// @param b the ProgramBuilder
- /// @param elem_value the value each element will be initialized with
- /// @return the list of expressions that are used to construct the matrix
- static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
- utils::Vector<const ast::Expression*, N> args;
- for (uint32_t i = 0; i < N; i++) {
- args.Push(DataType<vec<M, T>>::Expr(b, elem_value));
+ /// @param args args of size 1 or N*M with values of type T to initialize with
+ /// @return a new AST matrix value expression
+ static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
+ const bool one_value = args.values.Length() == 1;
+ size_t next = 0;
+ utils::Vector<const ast::Expression*, N> r;
+ for (uint32_t i = 0; i < N; ++i) {
+ if (one_value) {
+ r.Push(DataType<vec<M, T>>::Expr(b, args.values[0]));
+ } else {
+ utils::Vector<T, M> v;
+ for (size_t j = 0; j < M; ++j) {
+ v.Push(std::get<T>(args.values[next++]));
+ }
+ r.Push(DataType<vec<M, T>>::Expr(b, utils::VectorRef<T>{v}));
+ }
}
- return args;
+ return r;
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to ElementType
+ /// @return a new AST matrix value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -451,7 +542,7 @@
template <typename T, int ID>
struct DataType<alias<T, ID>> {
/// The element type
- using ElementType = T;
+ using ElementType = typename DataType<T>::ElementType;
/// true if the aliased type is a composite type
static constexpr bool is_composite = DataType<T>::is_composite;
@@ -471,24 +562,32 @@
static inline const sem::Type* Sem(ProgramBuilder& b) { return DataType<T>::Sem(b); }
/// @param b the ProgramBuilder
- /// @param elem_value the value nested elements will be initialized with
+ /// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- double elem_value) {
+ ScalarArgs args) {
// Cast
- return b.Construct(AST(b), DataType<T>::Expr(b, elem_value));
+ return b.Construct(AST(b), DataType<T>::Expr(b, std::move(args)));
}
/// @param b the ProgramBuilder
- /// @param elem_value the value nested elements will be initialized with
+ /// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
- double elem_value) {
+ ScalarArgs args) {
// Construct
- return b.Construct(AST(b), DataType<T>::ExprArgs(b, elem_value));
+ return b.Construct(AST(b), DataType<T>::ExprArgs(b, std::move(args)));
}
+
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to ElementType
+ /// @return a new AST expression of the alias type
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
+ }
+
/// @returns the WGSL name for the type
static inline std::string Name() { return "alias_" + std::to_string(ID); }
};
@@ -497,7 +596,7 @@
template <typename T>
struct DataType<ptr<T>> {
/// The element type
- using ElementType = T;
+ using ElementType = typename DataType<T>::ElementType;
/// true if the pointer type is a composite type
static constexpr bool is_composite = false;
@@ -516,12 +615,20 @@
}
/// @param b the ProgramBuilder
- /// @return a new AST expression of the alias type
- static inline const ast::Expression* Expr(ProgramBuilder& b, double /*unused*/) {
+ /// @return a new AST expression of the pointer type
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.GlobalVar(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
return b.AddressOf(sym);
}
+
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to ElementType
+ /// @return a new AST expression of the pointer type
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
+ }
+
/// @returns the WGSL name for the type
static inline std::string Name() { return "ptr<" + DataType<T>::Name() + ">"; }
};
@@ -530,7 +637,7 @@
template <uint32_t N, typename T>
struct DataType<array<N, T>> {
/// The element type
- using ElementType = T;
+ using ElementType = typename DataType<T>::ElementType;
/// true as arrays are a composite type
static constexpr bool is_composite = true;
@@ -556,22 +663,28 @@
/* implicit_stride */ el->Align());
}
/// @param b the ProgramBuilder
- /// @param elem_value the value each element in the array will be initialized
+ /// @param args args of size 1 or N with values of type T to initialize with
/// with
/// @return a new AST array value expression
- static inline const ast::Expression* Expr(ProgramBuilder& b, double elem_value) {
- return b.Construct(AST(b), ExprArgs(b, elem_value));
+ static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
+ return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
-
/// @param b the ProgramBuilder
- /// @param elem_value the value each element will be initialized with
+ /// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the array
- static inline auto ExprArgs(ProgramBuilder& b, double elem_value) {
- utils::Vector<const ast::Expression*, N> args;
+ static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
+ const bool one_value = args.values.Length() == 1;
+ utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) {
- args.Push(DataType<T>::Expr(b, elem_value));
+ r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i]));
}
- return args;
+ return r;
+ }
+ /// @param b the ProgramBuilder
+ /// @param v arg of type double that will be cast to ElementType
+ /// @return a new AST array value expression
+ static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
+ return Expr(b, static_cast<ElementType>(v));
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@@ -585,6 +698,8 @@
ast_type_func_ptr ast;
/// ast expression type create function
ast_expr_func_ptr expr;
+ /// ast expression type create function from double arg
+ ast_expr_from_double_func_ptr expr_from_double;
/// sem type create function
sem_type_func_ptr sem;
};
@@ -593,11 +708,129 @@
/// type `T`
template <typename T>
constexpr CreatePtrs CreatePtrsFor() {
- return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
+ return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem};
+}
+
+/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
+/// expressions with.
+template <typename T>
+struct Value {
+ /// Alias to T
+ using Type = T;
+ /// Alias to DataType<T>
+ using DataType = builder::DataType<T>;
+ /// Alias to DataType::ElementType
+ using ElementType = typename DataType::ElementType;
+
+ /// Creates a Value<T> with `args`
+ /// @param args the args that will be passed to the expression
+ /// @returns a Value<T>
+ static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; }
+
+ /// Creates an `ast::Expression` for the type T passing in previously stored args
+ /// @param b the ProgramBuilder
+ /// @returns an expression node
+ const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); }
+
+ /// ast expression type create function
+ ast_expr_func_ptr expr;
+ /// args to create expression with
+ ScalarArgs args;
+};
+
+namespace detail {
+/// Base template for IsValue
+template <typename T>
+struct IsValue : std::false_type {};
+/// Specialization for IsValue
+template <typename T>
+struct IsValue<Value<T>> : std::true_type {};
+} // namespace detail
+
+/// True if T is of type Value
+template <typename T>
+constexpr bool IsValue = detail::IsValue<T>::value;
+
+/// Creates a `Value<T>` from a scalar `v`
+template <typename T>
+auto Val(T v) {
+ return Value<T>::Create(v);
+}
+
+/// Creates a `Value<vec<N, T>>` from N scalar `args`
+template <typename... T>
+auto Vec(T&&... args) {
+ constexpr size_t N = sizeof...(args);
+ using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
+ utils::Vector v{args...};
+ using VT = vec<N, FirstT>;
+ return Value<VT>::Create(utils::VectorRef<FirstT>{v});
+}
+
+/// Creates a `Value<mat<C,R,T>` from C*R scalar `args`
+template <size_t C, size_t R, typename T>
+auto Mat(const T (&m_in)[C][R]) {
+ utils::Vector<T, C * R> m;
+ for (uint32_t i = 0; i < C; ++i) {
+ for (size_t j = 0; j < R; ++j) {
+ m.Push(m_in[i][j]);
+ }
+ }
+ return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
+}
+
+/// Creates a `Value<mat<2,R,T>` from column vectors `c0` and `c1`
+template <typename T, size_t R>
+auto Mat(const T (&c0)[R], const T (&c1)[R]) {
+ constexpr size_t C = 2;
+ utils::Vector<T, C * R> m;
+ for (auto v : c0) {
+ m.Push(v);
+ }
+ for (auto v : c1) {
+ m.Push(v);
+ }
+ return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
+}
+
+/// Creates a `Value<mat<3,R,T>` from column vectors `c0`, `c1`, and `c2`
+template <typename T, size_t R>
+auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
+ constexpr size_t C = 3;
+ utils::Vector<T, C * R> m;
+ for (auto v : c0) {
+ m.Push(v);
+ }
+ for (auto v : c1) {
+ m.Push(v);
+ }
+ for (auto v : c2) {
+ m.Push(v);
+ }
+ return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
+}
+
+/// Creates a `Value<mat<4,R,T>` from column vectors `c0`, `c1`, `c2`, and `c3`
+template <typename T, size_t R>
+auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
+ constexpr size_t C = 4;
+ utils::Vector<T, C * R> m;
+ for (auto v : c0) {
+ m.Push(v);
+ }
+ for (auto v : c1) {
+ m.Push(v);
+ }
+ for (auto v : c2) {
+ m.Push(v);
+ }
+ for (auto v : c3) {
+ m.Push(v);
+ }
+ return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
}
} // namespace builder
-
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_RESOLVER_TEST_HELPER_H_
diff --git a/src/tint/resolver/type_constructor_validation_test.cc b/src/tint/resolver/type_constructor_validation_test.cc
index 6d771f2..de64e9a 100644
--- a/src/tint/resolver/type_constructor_validation_test.cc
+++ b/src/tint/resolver/type_constructor_validation_test.cc
@@ -47,13 +47,13 @@
namespace InferTypeTest {
struct Params {
builder::ast_type_func_ptr create_rhs_ast_type;
- builder::ast_expr_func_ptr create_rhs_ast_value;
+ builder::ast_expr_from_double_func_ptr create_rhs_ast_value;
builder::sem_type_func_ptr create_rhs_sem_type;
};
template <typename T>
constexpr Params ParamsFor() {
- return Params{DataType<T>::AST, DataType<T>::Expr, DataType<T>::Sem};
+ return Params{DataType<T>::AST, DataType<T>::ExprFromDouble, DataType<T>::Sem};
}
TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) {
@@ -242,12 +242,13 @@
Kind kind;
builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type;
- builder::ast_expr_func_ptr rhs_value_expr;
+ builder::ast_expr_from_double_func_ptr rhs_value_expr;
};
template <typename LhsType, typename RhsType>
constexpr Params ParamsFor(Kind kind) {
- return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST, DataType<RhsType>::Expr};
+ return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST,
+ DataType<RhsType>::ExprFromDouble};
}
static constexpr Params valid_cases[] = {
@@ -426,7 +427,7 @@
// Skip test for valid cases
for (auto& v : valid_cases) {
if (v.lhs_type == lhs_params.ast && v.rhs_type == rhs_params.ast &&
- v.rhs_value_expr == rhs_params.expr) {
+ v.rhs_value_expr == rhs_params.expr_from_double) {
return;
}
}
@@ -439,7 +440,7 @@
auto* lhs_type1 = lhs_params.ast(*this);
auto* lhs_type2 = lhs_params.ast(*this);
auto* rhs_type = rhs_params.ast(*this);
- auto* rhs_value_expr = rhs_params.expr(*this, 0);
+ auto* rhs_value_expr = rhs_params.expr_from_double(*this, 0);
std::stringstream ss;
ss << FriendlyName(lhs_type1) << " = " << FriendlyName(lhs_type2) << "("
@@ -2437,7 +2438,7 @@
uint32_t columns;
name_func_ptr get_element_type_name;
builder::ast_type_func_ptr create_element_ast_type;
- builder::ast_expr_func_ptr create_element_ast_value;
+ builder::ast_expr_from_double_func_ptr create_element_ast_value;
builder::ast_type_func_ptr create_column_ast_type;
builder::ast_type_func_ptr create_mat_ast_type;
};
@@ -2449,7 +2450,7 @@
C,
DataType<T>::Name,
DataType<T>::AST,
- DataType<T>::Expr,
+ DataType<T>::ExprFromDouble,
DataType<tint::resolver::builder::vec<R, T>>::AST,
DataType<tint::resolver::builder::mat<C, R, T>>::AST,
};
@@ -3058,7 +3059,7 @@
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
if (i < N - 1) {
- auto* ctor_value_expr = str_params.expr(*this, 0);
+ auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
}
@@ -3084,7 +3085,7 @@
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
}
- auto* ctor_value_expr = str_params.expr(*this, 0);
+ auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
auto* s = Structure("s", members);
@@ -3122,8 +3123,8 @@
auto* struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
auto* ctor_value_expr = (i == constructor_value_with_different_type)
- ? ctor_params.expr(*this, 0)
- : str_params.expr(*this, 0);
+ ? ctor_params.expr_from_double(*this, 0)
+ : str_params.expr_from_double(*this, 0);
values.Push(ctor_value_expr);
}
auto* s = Structure("s", members);
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index 39ce5e3..c5b7220 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -75,7 +75,7 @@
TEST_F(ResolverTypeValidationTest, GlobalOverrideNoConstructor_Pass) {
// @id(0) override a :i32;
- Override(Source{{12, 34}}, "a", ty.i32(), Id(0));
+ Override(Source{{12, 34}}, "a", ty.i32(), Id(0_u));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 08451f0..caca32e 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -776,7 +776,7 @@
}
bool Validator::Override(
- const sem::Variable* v,
+ const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const {
auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef();
@@ -788,20 +788,11 @@
}
for (auto* attr : decl->attributes) {
- if (auto* id_attr = attr->As<ast::IdAttribute>()) {
- uint32_t id = id_attr->value;
- if (id > std::numeric_limits<decltype(OverrideId::value)>::max()) {
- AddError(
- "override IDs must be between 0 and " +
- std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
- attr->source);
- return false;
- }
- if (auto it =
- override_ids.find(OverrideId{static_cast<decltype(OverrideId::value)>(id)});
- it != override_ids.end() && it->second != v) {
+ if (attr->Is<ast::IdAttribute>()) {
+ auto id = v->OverrideId();
+ if (auto it = override_ids.find(id); it != override_ids.end() && it->second != v) {
AddError("override IDs must be unique", attr->source);
- AddNote("a override with an ID of " + std::to_string(id) +
+ AddNote("a override with an ID of " + std::to_string(id.value) +
" was previously declared here:",
ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
->source);
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index 57ac064..8bec86f 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -376,7 +376,7 @@
/// @param v the variable to validate
/// @param override_id the set of override ids in the module
/// @returns true on success, false otherwise.
- bool Override(const sem::Variable* v,
+ bool Override(const sem::GlobalVariable* v,
const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const;
/// Validates a 'const' variable declaration
diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc
index 3416909..9843d2f 100644
--- a/src/tint/resolver/variable_validation_test.cc
+++ b/src/tint/resolver/variable_validation_test.cc
@@ -98,7 +98,7 @@
// ...
// @id(N) override oN : i32;
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
- Override("reserved", ty.i32(), Id(kLimit));
+ Override("reserved", ty.i32(), Id(AInt(kLimit)));
for (size_t i = 0; i < kLimit; i++) {
Override("o" + std::to_string(i), ty.i32());
}
diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc
index 109904c..6663666 100644
--- a/src/tint/transform/array_length_from_uniform_test.cc
+++ b/src/tint/transform/array_length_from_uniform_test.cc
@@ -124,8 +124,9 @@
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
- EXPECT_EQ(std::unordered_set<uint32_t>({0}),
- got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
+ auto* val = got.data.Get<ArrayLengthFromUniform::Result>();
+ ASSERT_NE(val, nullptr);
+ EXPECT_EQ(std::unordered_set<uint32_t>({0}), val->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, BasicInStruct) {
diff --git a/src/tint/transform/vectorize_matrix_conversions.cc b/src/tint/transform/vectorize_matrix_conversions.cc
new file mode 100644
index 0000000..576b885
--- /dev/null
+++ b/src/tint/transform/vectorize_matrix_conversions.cc
@@ -0,0 +1,136 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/vectorize_matrix_conversions.h"
+
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/abstract_numeric.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/type_conversion.h"
+#include "src/tint/utils/hash.h"
+#include "src/tint/utils/map.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions);
+
+namespace tint::transform {
+
+VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
+
+VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
+
+bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
+ if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
+ if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<sem::Matrix>()) {
+ auto& args = call->Arguments();
+ if (args.Length() == 1 && args[0]->Type()->UnwrapRef()->is_float_matrix()) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+}
+
+void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ using HelperFunctionKey =
+ utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
+
+ std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
+
+ ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
+ auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* ty_conv = call->Target()->As<sem::TypeConversion>();
+ if (!ty_conv) {
+ return nullptr;
+ }
+ auto* dst_type = call->Type()->As<sem::Matrix>();
+ if (!dst_type) {
+ return nullptr;
+ }
+
+ auto& args = call->Arguments();
+ if (args.Length() != 1) {
+ return nullptr;
+ }
+
+ auto& src = args[0];
+
+ auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>();
+ if (!src_type) {
+ return nullptr;
+ }
+
+ // The source and destination type of a matrix conversion must have a same shape.
+ if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "source and destination matrix has different shape in matrix conversion";
+ return nullptr;
+ }
+
+ auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
+ utils::Vector<const ast::Expression*, 4> columns;
+ for (uint32_t c = 0; c < dst_type->columns(); c++) {
+ auto* src_matrix_expr = src_expression_builder();
+ auto* src_column_expr =
+ ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c)));
+ columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()),
+ src_column_expr));
+ }
+ return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns);
+ };
+
+ // Replace the matrix conversion to column vector conversions and a matrix construction.
+ if (!src->HasSideEffects()) {
+ // Simply use the argument's declaration if it has no side effects.
+ return build_vectorized_conversion_expression([&]() { //
+ return ctx.Clone(src->Declaration());
+ });
+ } else {
+ // If has side effects, use a helper function.
+ auto fn =
+ utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
+ auto name =
+ ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) +
+ "x" + std::to_string(src_type->rows()) + "_" +
+ ctx.dst->FriendlyName(src_type->type()) + "_" +
+ ctx.dst->FriendlyName(dst_type->type()));
+ ctx.dst->Func(
+ name,
+ utils::Vector{
+ ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)),
+ },
+ CreateASTTypeFor(ctx, dst_type),
+ utils::Vector{
+ ctx.dst->Return(build_vectorized_conversion_expression([&]() { //
+ return ctx.dst->Expr("value");
+ })),
+ });
+ return name;
+ });
+ return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
+ }
+ });
+
+ ctx.Clone();
+}
+
+} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_matrix_conversions.h b/src/tint/transform/vectorize_matrix_conversions.h
new file mode 100644
index 0000000..f16467c
--- /dev/null
+++ b/src/tint/transform/vectorize_matrix_conversions.h
@@ -0,0 +1,48 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
+#define SRC_TINT_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
+
+#include "src/tint/transform/transform.h"
+
+namespace tint::transform {
+
+/// A transform that converts matrix conversions (between f32 and f16 matrices) to the vector form.
+class VectorizeMatrixConversions final : public Castable<VectorizeMatrixConversions, Transform> {
+ public:
+ /// Constructor
+ VectorizeMatrixConversions();
+
+ /// Destructor
+ ~VectorizeMatrixConversions() override;
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+};
+
+} // namespace tint::transform
+
+#endif // SRC_TINT_TRANSFORM_VECTORIZE_MATRIX_CONVERSIONS_H_
diff --git a/src/tint/transform/vectorize_matrix_conversions_test.cc b/src/tint/transform/vectorize_matrix_conversions_test.cc
new file mode 100644
index 0000000..8142b58
--- /dev/null
+++ b/src/tint/transform/vectorize_matrix_conversions_test.cc
@@ -0,0 +1,411 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/vectorize_matrix_conversions.h"
+
+#include <string>
+#include <utility>
+
+#include "src/tint/transform/test_helper.h"
+#include "src/tint/utils/string.h"
+
+namespace tint::transform {
+namespace {
+
+using VectorizeMatrixConversionsTest = TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
+
+TEST_F(VectorizeMatrixConversionsTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+}
+
+// Test that VectorizeMatrixConversions transforms the matRxC<f32> to matRxC<f16> conversion as
+// expected.
+//
+// Example input:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f32>(vec2<f32>(0.0, 1.0), vec2<f32>(2.0, 3.0), vec2<f32>(4.0, 5.0));
+// let n : mat3x2<f16> = mat3x2<f16>(m);
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f32>(vec2<f32>(0.0, 1.0), vec2<f32>(2.0, 3.0), vec2<f32>(4.0, 5.0));
+// let n : mat3x2<f16> = mat3x2<f16>(vec2<f16>(m[0]), vec2<f16>(m[1]), vec2<f16>(m[2]));
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_F32ToF16) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string src_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string src_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string dst_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f16>";
+ std::string dst_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string vector_values;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ }
+ vector_values += src_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string vectorized_args = "";
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vectorized_args += ", ";
+ }
+ vectorized_args += dst_vec_type + "(m[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+@fragment
+fn main() {
+ let m = ${src_mat_type}(${values});
+ let n : ${dst_mat_type} = ${dst_mat_type}(${args});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${src_mat_type}", src_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${dst_mat_type}", dst_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${values}", vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${args}", "m");
+ auto expect = utils::ReplaceAll(tmpl, "${args}", vectorized_args);
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transforms the matRxC<f32> to matRxC<f16> conversion as
+// expected.
+//
+// Example input:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f16>(vec2<f16>(0.0, 1.0), vec2<f16>(2.0, 3.0), vec2<f16>(4.0, 5.0));
+// let n : mat3x2<f32> = mat3x2<f32>(m);
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// @fragment
+// fn main() {
+// let m = mat3x2<f16>(vec2<f16>(0.0, 1.0), vec2<f16>(2.0, 3.0), vec2<f16>(4.0, 5.0));
+// let n : mat3x2<f32> = mat3x2<f32>(vec2<f32>(m[0]), vec2<f32>(m[1]), vec2<f32>(m[2]));
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_F16ToF32) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string src_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f16>";
+ std::string src_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string dst_mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string dst_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string vector_values;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vector_values += ", ";
+ }
+ vector_values += src_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ vector_values += ", ";
+ }
+ auto value = std::to_string(c * rows + r) + ".0";
+ vector_values += value;
+ }
+ vector_values += ")";
+ }
+
+ std::string vectorized_args = "";
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ vectorized_args += ", ";
+ }
+ vectorized_args += dst_vec_type + "(m[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+@fragment
+fn main() {
+ let m = ${src_mat_type}(${values});
+ let n : ${dst_mat_type} = ${dst_mat_type}(${args});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${src_mat_type}", src_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${dst_mat_type}", dst_mat_type);
+ tmpl = utils::ReplaceAll(tmpl, "${values}", vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${args}", "m");
+ auto expect = utils::ReplaceAll(tmpl, "${args}", vectorized_args);
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transform generates help functions for conversions of which
+// input expression has side effect.
+//
+// Example input:
+//
+// enable f16;
+//
+// var<private> i : i32 = 0;
+//
+// fn mat_f32() -> mat2x2<f32> {
+// i = (i + 1);
+// return mat2x2<f32>(vec2<f32>(f32(i), f32(i)), vec2<f32>(f32(i), f32(i)));
+// }
+//
+// fn mat_f16() -> mat2x2<f16> {
+// i = (i + 1);
+// return mat2x2<f16>(vec2<f16>(f16(i), f16(i)), vec2<f16>(f16(i), f16(i)));
+// }
+//
+// @fragment
+// fn main() {
+// let m32 : mat2x2<f32> = mat2x2<f32>(mat_f16());
+// let m16 : mat2x2<f16> = mat2x2<f16>(mat_f32());
+// }
+//
+// Example output:
+//
+// enable f16;
+//
+// var<private> i : i32 = 0;
+//
+// fn mat_f32() -> mat2x2<f32> {
+// i = (i + 1);
+// return mat2x2<f32>(vec2<f32>(f32(i), f32(i)), vec2<f32>(f32(i), f32(i)));
+// }
+//
+// fn mat_f16() -> mat2x2<f16> {
+// i = (i + 1);
+// return mat2x2<f16>(vec2<f16>(f16(i), f16(i)), vec2<f16>(f16(i), f16(i)));
+// }
+//
+// fn convert_mat2x2_f16_f32(value : mat2x2<f16>) -> mat2x2<f32> {
+// return mat2x2<f32>(vec2<f32>(value[0]), vec2<f32>(value[1]));
+// }
+//
+// fn convert_mat2x2_f32_f16(value : mat2x2<f32>) -> mat2x2<f16> {
+// return mat2x2<f16>(vec2<f16>(value[0]), vec2<f16>(value[1]));
+// }
+//
+// @fragment
+// fn main() {
+// let m32 : mat2x2<f32> = convert_mat2x2_f16_f32(mat_f16());
+// let m16 : mat2x2<f16> = convert_mat2x2_f32_f16(mat_f32());
+// }
+TEST_P(VectorizeMatrixConversionsTest, Conversion_WithSideEffect) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_shape = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
+ std::string f32_mat_type = mat_shape + "<f32>";
+ std::string f32_vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string f16_mat_type = mat_shape + "<f16>";
+ std::string f16_vec_type = "vec" + std::to_string(rows) + "<f16>";
+ std::string f32_vector_values;
+ std::string f16_vector_values;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ f32_vector_values += ", ";
+ f16_vector_values += ", ";
+ }
+ f32_vector_values += f32_vec_type + "(";
+ f16_vector_values += f16_vec_type + "(";
+ for (uint32_t r = 0; r < rows; r++) {
+ if (r > 0) {
+ f32_vector_values += ", ";
+ f16_vector_values += ", ";
+ }
+ f32_vector_values += "f32(i)";
+ f16_vector_values += "f16(i)";
+ }
+ f32_vector_values += ")";
+ f16_vector_values += ")";
+ }
+
+ std::string f32_vectorized_args = "";
+ std::string f16_vectorized_args = "";
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ f32_vectorized_args += ", ";
+ f16_vectorized_args += ", ";
+ }
+ f32_vectorized_args += f32_vec_type + "(value[" + std::to_string(c) + "])";
+ f16_vectorized_args += f16_vec_type + "(value[" + std::to_string(c) + "])";
+ }
+
+ std::string tmpl = R"(
+enable f16;
+
+var<private> i : i32 = 0;
+
+fn mat_f32() -> ${f32_mat_type} {
+ i = (i + 1);
+ return ${f32_mat_type}(${f32_values});
+}
+
+fn mat_f16() -> ${f16_mat_type} {
+ i = (i + 1);
+ return ${f16_mat_type}(${f16_values});
+}
+${helper_function}
+@fragment
+fn main() {
+ let m32 : ${f32_mat_type} = ${f32_matrix_conversion};
+ let m16 : ${f16_mat_type} = ${f16_matrix_conversion};
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${f32_values}", f32_vector_values);
+ tmpl = utils::ReplaceAll(tmpl, "${f16_values}", f16_vector_values);
+ auto src = utils::ReplaceAll(tmpl, "${f32_matrix_conversion}", "${f32_mat_type}(mat_f16())");
+ src = utils::ReplaceAll(src, "${f16_matrix_conversion}", "${f16_mat_type}(mat_f32())");
+ src = utils::ReplaceAll(src, "${helper_function}", "");
+ src = utils::ReplaceAll(src, "${f32_mat_type}", f32_mat_type);
+ src = utils::ReplaceAll(src, "${f16_mat_type}", f16_mat_type);
+
+ auto helper_function = std::string(R"(
+fn convert_${mat_shape}_f16_f32(value : ${f16_mat_type}) -> ${f32_mat_type} {
+ return ${f32_mat_type}(${f32_vectorized_args});
+}
+
+fn convert_${mat_shape}_f32_f16(value : ${f32_mat_type}) -> ${f16_mat_type} {
+ return ${f16_mat_type}(${f16_vectorized_args});
+}
+)");
+ auto expect = utils::ReplaceAll(tmpl, "${helper_function}", helper_function);
+ expect = utils::ReplaceAll(expect, "${f32_mat_type}", f32_mat_type);
+ expect = utils::ReplaceAll(expect, "${f16_mat_type}", f16_mat_type);
+ expect = utils::ReplaceAll(expect, "${f32_matrix_conversion}",
+ "convert_${mat_shape}_f16_f32(mat_f16())");
+ expect = utils::ReplaceAll(expect, "${f16_matrix_conversion}",
+ "convert_${mat_shape}_f32_f16(mat_f32())");
+ expect = utils::ReplaceAll(expect, "${mat_shape}", mat_shape);
+ expect = utils::ReplaceAll(expect, "${f32_vectorized_args}", f32_vectorized_args);
+ expect = utils::ReplaceAll(expect, "${f16_vectorized_args}", f16_vectorized_args);
+
+ EXPECT_TRUE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transform will not run for matrix constructor.
+TEST_P(VectorizeMatrixConversionsTest, NonConversion_ConstructorFromVectors) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${columns});
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+// Test that VectorizeMatrixConversions transform will not run for identity matrix constructor,
+// which also take a single matrix as input.
+TEST_P(VectorizeMatrixConversionsTest, NonConversion_IdentityConstructor) {
+ uint32_t cols = GetParam().first;
+ uint32_t rows = GetParam().second;
+ std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
+ std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
+ std::string columns;
+ for (uint32_t c = 0; c < cols; c++) {
+ if (c > 0) {
+ columns += ", ";
+ }
+ columns += vec_type + "()";
+ }
+
+ std::string tmpl = R"(
+@fragment
+fn main() {
+ let m = ${matrix}(${columns});
+ let n : ${matrix} = ${matrix}(m);
+}
+)";
+ tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
+ auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
+ auto expect = src;
+
+ EXPECT_FALSE(ShouldRun<VectorizeMatrixConversions>(src));
+
+ auto got = Run<VectorizeMatrixConversions>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+INSTANTIATE_TEST_SUITE_P(VectorizeMatrixConversionsTest,
+ VectorizeMatrixConversionsTest,
+ testing::Values(std::make_pair(2, 2),
+ std::make_pair(2, 3),
+ std::make_pair(2, 4),
+ std::make_pair(3, 2),
+ std::make_pair(3, 3),
+ std::make_pair(3, 4),
+ std::make_pair(4, 2),
+ std::make_pair(4, 3),
+ std::make_pair(4, 4)));
+
+} // namespace
+} // namespace tint::transform
diff --git a/src/tint/utils/hash.h b/src/tint/utils/hash.h
index ad53841..717b35f 100644
--- a/src/tint/utils/hash.h
+++ b/src/tint/utils/hash.h
@@ -20,6 +20,7 @@
#include <functional>
#include <tuple>
#include <utility>
+#include <variant>
#include <vector>
#include "src/tint/utils/vector.h"
@@ -117,6 +118,16 @@
}
};
+/// Hasher specialization for std::tuple
+template <typename... TYPES>
+struct Hasher<std::variant<TYPES...>> {
+ /// @param variant the variant to hash
+ /// @returns a hash of the tuple
+ size_t operator()(const std::variant<TYPES...>& variant) const {
+ return std::visit([](auto&& val) { return Hash(val); }, variant);
+ }
+};
+
/// @returns a hash of the variadic list of arguments.
/// The returned hash is dependent on the order of the arguments.
template <typename... ARGS>
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
index e179389..87a0dd9 100644
--- a/src/tint/utils/hashmap.h
+++ b/src/tint/utils/hashmap.h
@@ -35,28 +35,6 @@
typename HASH = Hasher<K>,
typename EQUAL = std::equal_to<K>>
class Hashmap {
- /// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into
- /// the underlying Hashset.
- ///
- /// LazyCreator holds a #key, and a #create function used to build the final Entry::value.
- /// The #create function must be of the signature `V()`.
- ///
- /// LazyCreator can be compared to Entry and hashed, allowing them to be passed to
- /// Hashset::Insert(). If the set does not contain an existing entry with #key,
- /// Hashset::Insert() will construct a new Entry passing the rvalue LazyCreator as the
- /// constructor argument, which in turn calls the #create function to generate the entry value.
- ///
- /// @see Entry
- /// @see Hasher
- /// @see Equality
- template <typename CREATE>
- struct LazyCreator {
- /// The key of the entry to insert into the map
- const K& key;
- /// The value creation function
- CREATE create;
- };
-
/// Entry holds a key and value pair, and is used as the element type of the underlying Hashset.
/// Entries are compared and hashed using only the #key.
/// @see Hasher
@@ -71,23 +49,6 @@
/// Move-constructor.
Entry(Entry&&) = default;
- /// Constructor from a LazyCreator.
- /// The constructor invokes the LazyCreator::create function to build the #value.
- /// @see LazyCreator
- template <typename CREATE>
- Entry(const LazyCreator<CREATE>& creator) // NOLINT(runtime/explicit)
- : key(creator.key), value(creator.create()) {}
-
- /// Assignment operator from a LazyCreator.
- /// The assignment invokes the LazyCreator::create function to build the #value.
- /// @see LazyCreator
- template <typename CREATE>
- Entry& operator=(LazyCreator<CREATE>&& creator) {
- key = std::move(creator.key);
- value = creator.create();
- return *this;
- }
-
/// Copy-assignment operator
Entry& operator=(const Entry&) = default;
@@ -99,33 +60,23 @@
};
/// Hash provider for the underlying Hashset.
- /// Provides hash functions for an Entry, K or LazyCreator.
+ /// Provides hash functions for an Entry or K.
/// The hash functions only consider the key of an entry.
struct Hasher {
/// Calculates a hash from an Entry
size_t operator()(const Entry& entry) const { return HASH()(entry.key); }
/// Calculates a hash from a K
size_t operator()(const K& key) const { return HASH()(key); }
- /// Calculates a hash from a LazyCreator
- template <typename CREATE>
- size_t operator()(const LazyCreator<CREATE>& lc) const {
- return HASH()(lc.key);
- }
};
/// Equality provider for the underlying Hashset.
- /// Provides equality functions for an Entry, K or LazyCreator to an Entry.
+ /// Provides equality functions for an Entry or K to an Entry.
/// The equality functions only consider the key for equality.
struct Equality {
/// Compares an Entry to an Entry for equality.
bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); }
/// Compares a K to an Entry for equality.
bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); }
- /// Compares a LazyCreator to an Entry for equality.
- template <typename CREATE>
- bool operator()(const LazyCreator<CREATE>& lc, const Entry& b) const {
- return EQUAL()(lc.key, b.key);
- }
};
/// The underlying set
@@ -151,7 +102,8 @@
/// Used by gmock for the `ElementsAre` checks.
using value_type = KeyValue;
- /// Iterator for the map
+ /// Iterator for the map.
+ /// Iterators are invalidated if the map is modified.
class Iterator {
public:
/// @returns the key of the entry pointed to by this iterator
@@ -226,13 +178,29 @@
/// Searches for an entry with the given key value, adding and returning the result of
/// calling `create` if the entry was not found.
+ /// @note: Before calling `create`, the map will insert a zero-initialized value for the given
+ /// key, which will be replaced with the value returned by `create`. If `create` adds an entry
+ /// with `key` to this map, it will be replaced.
/// @param key the entry's key value to search for.
/// @param create the create function to call if the map does not contain the key.
/// @returns the value of the entry.
template <typename CREATE>
V& GetOrCreate(const K& key, CREATE&& create) {
- LazyCreator<CREATE> lc{key, std::forward<CREATE>(create)};
- auto res = set_.Add(std::move(lc));
+ auto res = set_.Add(Entry{key, V{}});
+ if (res.action == AddAction::kAdded) {
+ // Store the set generation before calling create()
+ auto generation = set_.Generation();
+ // Call create(), which might modify this map.
+ auto value = create();
+ // Was this map mutated?
+ if (set_.Generation() == generation) {
+ // Calling create() did not touch the map. No need to lookup again.
+ res.entry->value = std::move(value);
+ } else {
+ // Calling create() modified the map. Need to insert again.
+ res = set_.Replace(Entry{key, std::move(value)});
+ }
+ }
return res.entry->value;
}
@@ -241,9 +209,7 @@
/// @param key the entry's key value to search for.
/// @returns the value of the entry.
V& GetOrZero(const K& key) {
- auto zero = [] { return V{}; };
- LazyCreator<decltype(zero)> lc{key, zero};
- auto res = set_.Add(std::move(lc));
+ auto res = set_.Add(Entry{key, V{}});
return res.entry->value;
}
@@ -288,6 +254,9 @@
/// @returns the number of entries in the map.
size_t Count() const { return set_.Count(); }
+ /// @returns a monotonic counter which is incremented whenever the map is mutated.
+ size_t Generation() const { return set_.Generation(); }
+
/// @returns true if the map contains no entries.
bool IsEmpty() const { return set_.IsEmpty(); }
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
index 45e929b..9a5b01e 100644
--- a/src/tint/utils/hashmap_test.cc
+++ b/src/tint/utils/hashmap_test.cc
@@ -69,6 +69,27 @@
EXPECT_FALSE(map.Contains("world"));
}
+TEST(Hashmap, Generation) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.Generation(), 0u);
+ map.Add(1, "one");
+ EXPECT_EQ(map.Generation(), 1u);
+ map.Add(1, "uno");
+ EXPECT_EQ(map.Generation(), 1u);
+ map.Replace(1, "une");
+ EXPECT_EQ(map.Generation(), 2u);
+ map.Add(2, "dos");
+ EXPECT_EQ(map.Generation(), 3u);
+ map.Remove(1);
+ EXPECT_EQ(map.Generation(), 4u);
+ map.Clear();
+ EXPECT_EQ(map.Generation(), 5u);
+ map.Find(2);
+ EXPECT_EQ(map.Generation(), 5u);
+ map.Get(2);
+ EXPECT_EQ(map.Generation(), 5u);
+}
+
TEST(Hashmap, Iterator) {
using Map = Hashmap<int, std::string, 8>;
using KV = typename Map::KeyValue;
@@ -98,9 +119,16 @@
TEST(Hashmap, GetOrCreate) {
Hashmap<int, std::string, 8> map;
- EXPECT_EQ(map.GetOrCreate(0, [&] { return "zero"; }), "zero");
+ std::optional<std::string> value_of_key_0_at_create;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ value_of_key_0_at_create = map.Get(0);
+ return "zero";
+ }),
+ "zero");
EXPECT_EQ(map.Count(), 1u);
EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(value_of_key_0_at_create, "");
bool create_called = false;
EXPECT_EQ(map.GetOrCreate(0,
@@ -118,6 +146,67 @@
EXPECT_EQ(map.Get(1), "one");
}
+TEST(Hashmap, GetOrCreate_CreateModifiesMap) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ map.Add(3, "three");
+ map.Add(1, "one");
+ map.Add(2, "two");
+ return "zero";
+ }),
+ "zero");
+ EXPECT_EQ(map.Count(), 4u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+
+ bool create_called = false;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ create_called = true;
+ return "oh noes";
+ }),
+ "zero");
+ EXPECT_FALSE(create_called);
+ EXPECT_EQ(map.Count(), 4u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+
+ EXPECT_EQ(map.GetOrCreate(4,
+ [&] {
+ map.Add(6, "six");
+ map.Add(5, "five");
+ map.Add(7, "seven");
+ return "four";
+ }),
+ "four");
+ EXPECT_EQ(map.Count(), 8u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+ EXPECT_EQ(map.Get(4), "four");
+ EXPECT_EQ(map.Get(5), "five");
+ EXPECT_EQ(map.Get(6), "six");
+ EXPECT_EQ(map.Get(7), "seven");
+}
+
+TEST(Hashmap, GetOrCreate_CreateAddsSameKeyedValue) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.GetOrCreate(42,
+ [&] {
+ map.Add(42, "should-be-replaced");
+ return "expected-value";
+ }),
+ "expected-value");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_EQ(map.Get(42), "expected-value");
+}
+
TEST(Hashmap, Soak) {
std::mt19937 rnd;
std::unordered_map<std::string, std::string> reference;
diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h
index 3009c72..f7d5efe 100644
--- a/src/tint/utils/hashset.h
+++ b/src/tint/utils/hashset.h
@@ -73,7 +73,8 @@
static constexpr size_t kMinSlots = std::max<size_t>(kNumFixedSlots, 4);
public:
- /// Iterator for entries in the set
+ /// Iterator for entries in the set.
+ /// Iterators are invalidated if the set is modified.
class Iterator {
public:
/// @returns the value pointed to by this iterator
@@ -152,6 +153,7 @@
slots_.Clear(); // Destructs all entries
slots_.Resize(kMinSlots);
count_ = 0;
+ generation_++;
}
/// Result of Add()
@@ -219,6 +221,7 @@
// Entry was removed.
count_--;
+ generation_++;
return true;
}
@@ -299,6 +302,9 @@
/// @returns true if the set contains no entries.
bool IsEmpty() const { return count_ == 0; }
+ /// @returns a monotonic counter which is incremented whenever the set is mutated.
+ size_t Generation() const { return generation_; }
+
/// @returns an iterator to the start of the set.
Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; }
@@ -351,6 +357,7 @@
slot.hash = hash.value;
slot.distance = distance;
count_++;
+ generation_++;
result = AddResult{AddAction::kAdded, &slot.value.value()};
return Action::kStop;
}
@@ -361,6 +368,7 @@
// Slot is equal to value. Replace or preserve?
if constexpr (MODE == PutMode::kReplace) {
slot.value = std::forward<V>(value);
+ generation_++;
result = AddResult{AddAction::kReplaced, &slot.value.value()};
} else {
result = AddResult{AddAction::kKeptExisting, &slot.value.value()};
@@ -380,6 +388,7 @@
InsertShuffle(Wrap(index + 1), std::move(evicted));
count_++;
+ generation_++;
result = AddResult{AddAction::kAdded, &slot.value.value()};
return Action::kStop;
@@ -502,6 +511,9 @@
/// The number of entries in the set.
size_t count_ = 0;
+
+ /// Counter that's incremented with each modification to the set.
+ size_t generation_ = 0;
};
} // namespace tint::utils
diff --git a/src/tint/utils/hashset_test.cc b/src/tint/utils/hashset_test.cc
index 4213b32..6e8d1cc 100644
--- a/src/tint/utils/hashset_test.cc
+++ b/src/tint/utils/hashset_test.cc
@@ -67,6 +67,27 @@
}
}
+TEST(Hashset, Generation) {
+ Hashset<int, 8> set;
+ EXPECT_EQ(set.Generation(), 0u);
+ set.Add(1);
+ EXPECT_EQ(set.Generation(), 1u);
+ set.Add(1);
+ EXPECT_EQ(set.Generation(), 1u);
+ set.Replace(1);
+ EXPECT_EQ(set.Generation(), 2u);
+ set.Add(2);
+ EXPECT_EQ(set.Generation(), 3u);
+ set.Remove(1);
+ EXPECT_EQ(set.Generation(), 4u);
+ set.Clear();
+ EXPECT_EQ(set.Generation(), 5u);
+ set.Find(2);
+ EXPECT_EQ(set.Generation(), 5u);
+ set.Get(2);
+ EXPECT_EQ(set.Generation(), 5u);
+}
+
TEST(Hashset, Iterator) {
Hashset<std::string, 8> set;
set.Add("one");
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index b718840..f0cbf58 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -243,11 +243,11 @@
/// Move constructor from a mutable vector reference
/// @param other the vector reference to move
- explicit Vector(VectorRef<T>&& other) { MoveOrCopy(std::move(other)); }
+ Vector(VectorRef<T>&& other) { MoveOrCopy(std::move(other)); } // NOLINT(runtime/explicit)
/// Copy constructor from an immutable vector reference
/// @param other the vector reference to copy
- explicit Vector(const VectorRef<T>& other) { Copy(other.slice_); }
+ Vector(const VectorRef<T>& other) { Copy(other.slice_); } // NOLINT(runtime/explicit)
/// Destructor
~Vector() { ClearAndFree(); }
@@ -475,6 +475,12 @@
return true;
}
+ /// Inequality operator
+ /// @param other the other vector
+ /// @returns true if this vector is not the same length as `other`, or all elements are not
+ /// equal.
+ bool operator!=(const Vector& other) const { return !(*this == other); }
+
private:
/// Friend class (differing specializations of this class)
template <typename, size_t>
diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc
index 24747a9..83ba37b 100644
--- a/src/tint/writer/glsl/generator_impl_function_test.cc
+++ b/src/tint/writer/glsl/generator_impl_function_test.cc
@@ -804,9 +804,9 @@
TEST_F(GlslGeneratorImplTest_Function,
Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9u));
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
Func("main", utils::Empty, ty.void_(), {},
utils::Vector{
Stage(ast::PipelineStage::kCompute),
diff --git a/src/tint/writer/glsl/generator_impl_module_constant_test.cc b/src/tint/writer/glsl/generator_impl_module_constant_test.cc
index e607efc..26dbf2e 100644
--- a/src/tint/writer/glsl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/glsl/generator_impl_module_constant_test.cc
@@ -346,7 +346,7 @@
}
TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23));
+ auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
GeneratorImpl& gen = Build();
@@ -359,7 +359,7 @@
}
TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
- auto* var = Override("pos", ty.f32(), Id(23));
+ auto* var = Override("pos", ty.f32(), Id(23_a));
GeneratorImpl& gen = Build();
@@ -372,7 +372,7 @@
}
TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
- auto* a = Override("a", ty.f32(), Expr(3_f), Id(0));
+ auto* a = Override("a", ty.f32(), Expr(3_f), Id(0_a));
auto* b = Override("b", ty.f32(), Expr(2_f));
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc
index c1124a5..14e8a70 100644
--- a/src/tint/writer/hlsl/generator_impl_function_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_function_test.cc
@@ -714,9 +714,9 @@
TEST_F(HlslGeneratorImplTest_Function,
Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9u));
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
diff --git a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
index 49cf831..58fdbf1 100644
--- a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
@@ -243,7 +243,7 @@
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23));
+ auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
GeneratorImpl& gen = Build();
@@ -256,7 +256,7 @@
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
- auto* var = Override("pos", ty.f32(), Id(23));
+ auto* var = Override("pos", ty.f32(), Id(23_a));
GeneratorImpl& gen = Build();
@@ -269,7 +269,7 @@
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
- auto* a = Override("a", ty.f32(), Expr(3_f), Id(0));
+ auto* a = Override("a", ty.f32(), Expr(3_f), Id(0_a));
auto* b = Override("b", ty.f32(), Expr(2_f));
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/msl/generator_impl_module_constant_test.cc b/src/tint/writer/msl/generator_impl_module_constant_test.cc
index fd7b5d6..9f665fd 100644
--- a/src/tint/writer/msl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/msl/generator_impl_module_constant_test.cc
@@ -329,7 +329,7 @@
}
TEST_F(MslGeneratorImplTest, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23));
+ auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
GeneratorImpl& gen = Build();
@@ -338,7 +338,7 @@
}
TEST_F(MslGeneratorImplTest, Emit_Override_NoId) {
- auto* var_a = Override("a", ty.f32(), Id(0));
+ auto* var_a = Override("a", ty.f32(), Id(0_a));
auto* var_b = Override("b", ty.f32());
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 55b8bc7..c2056a4 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -1390,8 +1390,7 @@
auto* value_type = args[0]->Type()->UnwrapRef();
if (auto* val_mat = value_type->As<sem::Matrix>()) {
// Generate passthrough for matrices of the same type
- can_cast_or_copy =
- (res_mat->columns() == val_mat->columns()) && (res_mat->rows() == val_mat->rows());
+ can_cast_or_copy = res_mat == val_mat;
}
}
@@ -1578,13 +1577,19 @@
} else if ((from_type->is_float_scalar() && to_type->Is<sem::U32>()) ||
(from_type->is_float_vector() && to_type->is_unsigned_integer_vector())) {
op = spv::Op::OpConvertFToU;
- } else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
- (from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
- (from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
- (from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
- (from_type->Is<sem::F16>() && to_type->Is<sem::F16>()) ||
- (from_type->Is<sem::Vector>() && (from_type == to_type))) {
+ } else if (from_type
+ ->IsAnyOf<sem::Bool, sem::F32, sem::I32, sem::U32, sem::F16, sem::Vector>() &&
+ from_type == to_type) {
+ // Identity constructor for scalar and vector types
return val_id;
+ } else if ((from_type->is_float_scalar() && to_type->is_float_scalar()) ||
+ (from_type->is_float_vector() && to_type->is_float_vector() &&
+ from_type->As<sem::Vector>()->Width() == to_type->As<sem::Vector>()->Width())) {
+ // Convert between f32 and f16 types.
+ // OpFConvert requires the scalar component types to be different, and the case of from_type
+ // and to_type being the same floating point scalar or vector type, i.e. identity
+ // constructor, is already handled in the previous else-if clause.
+ op = spv::Op::OpFConvert;
} else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
(from_type->Is<sem::U32>() && to_type->Is<sem::I32>()) ||
(from_type->is_signed_integer_vector() && to_type->is_unsigned_integer_vector()) ||
@@ -1644,8 +1649,18 @@
}
return result_id;
- } else if (from_type->Is<sem::Matrix>()) {
- return val_id;
+ } else if (from_type->Is<sem::Matrix>() && to_type->Is<sem::Matrix>()) {
+ // SPIRV does not support matrix conversion, the only valid case is matrix identity
+ // constructor. Matrix conversion between f32 and f16 should be transformed into vector
+ // conversions for each column vectors by VectorizeMatrixConversions.
+ auto* from_mat = from_type->As<sem::Matrix>();
+ auto* to_mat = to_type->As<sem::Matrix>();
+ if (from_mat == to_mat) {
+ return val_id;
+ }
+ TINT_ICE(Writer, builder_.Diagnostics())
+ << "matrix conversion is not supported and should have been handled by "
+ "VectorizeMatrixConversions";
} else {
TINT_ICE(Writer, builder_.Diagnostics()) << "Invalid from_type";
}
diff --git a/src/tint/writer/spirv/builder_constructor_expression_test.cc b/src/tint/writer/spirv/builder_constructor_expression_test.cc
index 9c9bd74..cb8dd4f 100644
--- a/src/tint/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/tint/writer/spirv/builder_constructor_expression_test.cc
@@ -3927,30 +3927,6 @@
)");
}
-TEST_F(SpvBuilderConstructorTest, Type_Convert_I32_To_U32) {
- auto* var = Decl(Var("x", ty.i32(), Expr(2_i)));
- auto* cast = Construct<u32>("x");
- WrapInFunction(var, cast);
-
- spirv::Builder& b = Build();
-
- b.push_function(Function{});
- EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
- EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
-
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
-%2 = OpConstant %1 2
-%4 = OpTypePointer Function %1
-%5 = OpConstantNull %1
-%7 = OpTypeInt 32 0
-)");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
- R"(OpStore %3 %2
-%8 = OpLoad %1 %3
-%6 = OpBitcast %7 %8
-)");
-}
-
TEST_F(SpvBuilderConstructorTest, Type_Convert_F32_To_I32) {
auto* var = Decl(Var("x", ty.f32(), Expr(2.4_f)));
auto* cast = Construct<i32>("x");
@@ -4001,6 +3977,30 @@
)");
}
+TEST_F(SpvBuilderConstructorTest, Type_Convert_I32_To_U32) {
+ auto* var = Decl(Var("x", ty.i32(), Expr(2_i)));
+ auto* cast = Construct<u32>("x");
+ WrapInFunction(var, cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
+%2 = OpConstant %1 2
+%4 = OpTypePointer Function %1
+%5 = OpConstantNull %1
+%7 = OpTypeInt 32 0
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %3 %2
+%8 = OpLoad %1 %3
+%6 = OpBitcast %7 %8
+)");
+}
+
TEST_F(SpvBuilderConstructorTest, Type_Convert_F32_To_U32) {
auto* var = Decl(Var("x", ty.f32(), Expr(2.4_f)));
auto* cast = Construct<u32>("x");
@@ -4075,6 +4075,56 @@
)");
}
+TEST_F(SpvBuilderConstructorTest, Type_Convert_U32_To_F32) {
+ auto* var = Decl(Var("x", ty.u32(), Expr(2_u)));
+ auto* cast = Construct<f32>("x");
+ WrapInFunction(var, cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
+%2 = OpConstant %1 2
+%4 = OpTypePointer Function %1
+%5 = OpConstantNull %1
+%7 = OpTypeFloat 32
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %3 %2
+%8 = OpLoad %1 %3
+%6 = OpConvertUToF %7 %8
+)");
+}
+
+TEST_F(SpvBuilderConstructorTest, Type_Convert_F16_To_F32) {
+ Enable(ast::Extension::kF16);
+
+ auto* var = Decl(Var("x", ty.f16(), Expr(2_h)));
+ auto* cast = Construct<f32>("x");
+ WrapInFunction(var, cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
+%2 = OpConstant %1 0x1p+1
+%4 = OpTypePointer Function %1
+%5 = OpConstantNull %1
+%7 = OpTypeFloat 32
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %3 %2
+%8 = OpLoad %1 %3
+%6 = OpFConvert %7 %8
+)");
+}
+
TEST_F(SpvBuilderConstructorTest, Type_Convert_I32_To_F16) {
Enable(ast::Extension::kF16);
@@ -4101,30 +4151,6 @@
)");
}
-TEST_F(SpvBuilderConstructorTest, Type_Convert_U32_To_F32) {
- auto* var = Decl(Var("x", ty.u32(), Expr(2_u)));
- auto* cast = Construct<f32>("x");
- WrapInFunction(var, cast);
-
- spirv::Builder& b = Build();
-
- b.push_function(Function{});
- EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
- EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
-
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
-%2 = OpConstant %1 2
-%4 = OpTypePointer Function %1
-%5 = OpConstantNull %1
-%7 = OpTypeFloat 32
-)");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
- R"(OpStore %3 %2
-%8 = OpLoad %1 %3
-%6 = OpConvertUToF %7 %8
-)");
-}
-
TEST_F(SpvBuilderConstructorTest, Type_Convert_U32_To_F16) {
Enable(ast::Extension::kF16);
@@ -4151,6 +4177,32 @@
)");
}
+TEST_F(SpvBuilderConstructorTest, Type_Convert_F32_To_F16) {
+ Enable(ast::Extension::kF16);
+
+ auto* var = Decl(Var("x", ty.f32(), Expr(2_f)));
+ auto* cast = Construct<f16>("x");
+ WrapInFunction(var, cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ EXPECT_TRUE(b.GenerateStatement(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpConstant %1 2
+%4 = OpTypePointer Function %1
+%5 = OpConstantNull %1
+%7 = OpTypeFloat 16
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpStore %3 %2
+%8 = OpLoad %1 %3
+%6 = OpFConvert %7 %8
+)");
+}
+
TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_U32_to_I32) {
auto* var = GlobalVar("i", ty.vec3<u32>(), ast::StorageClass::kPrivate);
@@ -4337,6 +4389,60 @@
)");
}
+TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_U32_to_F32) {
+ auto* var = GlobalVar("i", ty.vec3<u32>(), ast::StorageClass::kPrivate);
+
+ auto* cast = vec3<f32>("i");
+ WrapInFunction(cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpConvertUToF %7 %9
+)");
+}
+
+TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_F16_to_F32) {
+ Enable(ast::Extension::kF16);
+
+ auto* var = GlobalVar("i", ty.vec3<f16>(), ast::StorageClass::kPrivate);
+
+ auto* cast = vec3<f32>("i");
+ WrapInFunction(cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 16
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 32
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpFConvert %7 %9
+)");
+}
+
TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_I32_to_F16) {
Enable(ast::Extension::kF16);
@@ -4365,32 +4471,6 @@
)");
}
-TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_U32_to_F32) {
- auto* var = GlobalVar("i", ty.vec3<u32>(), ast::StorageClass::kPrivate);
-
- auto* cast = vec3<f32>("i");
- WrapInFunction(cast);
-
- spirv::Builder& b = Build();
-
- b.push_function(Function{});
- ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
- EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
-
- EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
-%3 = OpTypeVector %4 3
-%2 = OpTypePointer Private %3
-%5 = OpConstantNull %3
-%1 = OpVariable %2 Private %5
-%8 = OpTypeFloat 32
-%7 = OpTypeVector %8 3
-)");
- EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
- R"(%9 = OpLoad %3 %1
-%6 = OpConvertUToF %7 %9
-)");
-}
-
TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_U32_to_F16) {
Enable(ast::Extension::kF16);
@@ -4419,6 +4499,34 @@
)");
}
+TEST_F(SpvBuilderConstructorTest, Type_Convert_Vectors_F32_to_F16) {
+ Enable(ast::Extension::kF16);
+
+ auto* var = GlobalVar("i", ty.vec3<f32>(), ast::StorageClass::kPrivate);
+
+ auto* cast = vec3<f16>("i");
+ WrapInFunction(cast);
+
+ spirv::Builder& b = Build();
+
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+ EXPECT_EQ(b.GenerateExpression(cast), 6u) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+%8 = OpTypeFloat 16
+%7 = OpTypeVector %8 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%9 = OpLoad %3 %1
+%6 = OpFConvert %7 %9
+)");
+}
+
TEST_F(SpvBuilderConstructorTest, IsConstructorConst_GlobalVectorWithAllConstConstructors) {
// vec3<f32>(1.0, 2.0, 3.0) -> true
auto* t = vec3<f32>(1_f, 2_f, 3_f);
diff --git a/src/tint/writer/spirv/builder_function_attribute_test.cc b/src/tint/writer/spirv/builder_function_attribute_test.cc
index 554c028..8825c79 100644
--- a/src/tint/writer/spirv/builder_function_attribute_test.cc
+++ b/src/tint/writer/spirv/builder_function_attribute_test.cc
@@ -150,9 +150,9 @@
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9u));
+ Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
+ Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
+ Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
WorkgroupSize("width", "height", "depth"),
@@ -180,7 +180,7 @@
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) {
- Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7u));
+ Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3_i));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
diff --git a/src/tint/writer/spirv/builder_global_variable_test.cc b/src/tint/writer/spirv/builder_global_variable_test.cc
index e4a6a48..dfb38e5 100644
--- a/src/tint/writer/spirv/builder_global_variable_test.cc
+++ b/src/tint/writer/spirv/builder_global_variable_test.cc
@@ -250,7 +250,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Bool) {
- auto* v = Override("var", ty.bool_(), Expr(true), Id(1200));
+ auto* v = Override("var", ty.bool_(), Expr(true), Id(1200_a));
spirv::Builder& b = Build();
@@ -265,7 +265,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Bool_ZeroValue) {
- auto* v = Override("var", ty.bool_(), Construct<bool>(), Id(1200));
+ auto* v = Override("var", ty.bool_(), Construct<bool>(), Id(1200_a));
spirv::Builder& b = Build();
@@ -280,7 +280,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Bool_NoConstructor) {
- auto* v = Override("var", ty.bool_(), Id(1200));
+ auto* v = Override("var", ty.bool_(), Id(1200_a));
spirv::Builder& b = Build();
@@ -295,7 +295,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Scalar) {
- auto* v = Override("var", ty.f32(), Expr(2_f), Id(0));
+ auto* v = Override("var", ty.f32(), Expr(2_f), Id(0_a));
spirv::Builder& b = Build();
@@ -310,7 +310,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Scalar_ZeroValue) {
- auto* v = Override("var", ty.f32(), Construct<f32>(), Id(0));
+ auto* v = Override("var", ty.f32(), Construct<f32>(), Id(0_a));
spirv::Builder& b = Build();
@@ -325,7 +325,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Scalar_F32_NoConstructor) {
- auto* v = Override("var", ty.f32(), Id(0));
+ auto* v = Override("var", ty.f32(), Id(0_a));
spirv::Builder& b = Build();
@@ -340,7 +340,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Scalar_I32_NoConstructor) {
- auto* v = Override("var", ty.i32(), Id(0));
+ auto* v = Override("var", ty.i32(), Id(0_a));
spirv::Builder& b = Build();
@@ -355,7 +355,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_Scalar_U32_NoConstructor) {
- auto* v = Override("var", ty.u32(), Id(0));
+ auto* v = Override("var", ty.u32(), Id(0_a));
spirv::Builder& b = Build();
@@ -370,7 +370,7 @@
}
TEST_F(BuilderTest, GlobalVar_Override_NoId) {
- auto* var_a = Override("a", ty.bool_(), Expr(true), Id(0));
+ auto* var_a = Override("a", ty.bool_(), Expr(true), Id(0_a));
auto* var_b = Override("b", ty.bool_(), Expr(false));
spirv::Builder& b = Build();
diff --git a/src/tint/writer/spirv/generator_impl.cc b/src/tint/writer/spirv/generator_impl.cc
index 8586562..ace5209 100644
--- a/src/tint/writer/spirv/generator_impl.cc
+++ b/src/tint/writer/spirv/generator_impl.cc
@@ -32,6 +32,7 @@
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/var_for_dynamic_index.h"
+#include "src/tint/transform/vectorize_matrix_conversions.h"
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
@@ -78,6 +79,7 @@
manager.Add<transform::SimplifyPointers>(); // Required for arrayLength()
manager.Add<transform::RemovePhonies>();
manager.Add<transform::VectorizeScalarMatrixConstructors>();
+ manager.Add<transform::VectorizeMatrixConversions>();
manager.Add<transform::ForLoopToLoop>(); // Must come after
manager.Add<transform::WhileToLoop>(); // ZeroInitWorkgroupMemory
manager.Add<transform::CanonicalizeEntryPointIO>();
diff --git a/src/tint/writer/spirv/operand.h b/src/tint/writer/spirv/operand.h
index dab10e9..0601ca0 100644
--- a/src/tint/writer/spirv/operand.h
+++ b/src/tint/writer/spirv/operand.h
@@ -43,18 +43,4 @@
} // namespace tint::writer::spirv
-namespace std {
-
-/// Custom std::hash specialization for tint::writer::spirv::Operand
-template <>
-class hash<tint::writer::spirv::Operand> {
- public:
- /// @param o the Operand
- /// @return the hash value
- inline std::size_t operator()(const tint::writer::spirv::Operand& o) const {
- return std::visit([](auto v) { return tint::utils::Hash(v); }, o);
- }
-};
-
-} // namespace std
#endif // SRC_TINT_WRITER_SPIRV_OPERAND_H_
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 755fe6b..bd75fb8 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -776,7 +776,11 @@
return true;
},
[&](const ast::IdAttribute* override_deco) {
- out << "id(" << override_deco->value << ")";
+ out << "id(";
+ if (!EmitExpression(out, override_deco->value)) {
+ return false;
+ }
+ out << ")";
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
diff --git a/src/tint/writer/wgsl/generator_impl_global_decl_test.cc b/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
index 06a5e19..bc77623 100644
--- a/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_global_decl_test.cc
@@ -144,7 +144,7 @@
TEST_F(WgslGeneratorImplTest, Emit_OverridableConstants) {
Override("a", ty.f32());
- Override("b", ty.f32(), Id(7u));
+ Override("b", ty.f32(), Id(7_a));
GeneratorImpl& gen = Build();