Import Tint changes from Dawn
Changes:
- b607bfbddf8d206021f71ba6557cba06b950b5d5 tint/utils: Manually inline HashmapBase::Scan() by Ben Clayton <bclayton@google.com>
- 973a685ad3589dd9c14c0ed306d7ae62648cbfd2 tint/transform: Skip LocalizeStructArrayAssignment if pos... by Ben Clayton <bclayton@google.com>
- 6345562a989f9294a3be4bcd4acece1a0438663a resolver: Delay copy of TemplateState by Ben Clayton <bclayton@google.com>
- 47a81fc126004a876ccb456a1c867f3d33f84eed tint/utils: Use a C-array instead of std::array by Ben Clayton <bclayton@google.com>
- 7092786f313d07d3c73ba02bbe3139d4ecc6ed8e Fixup return of HLSL sign to match WGSL. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: b607bfbddf8d206021f71ba6557cba06b950b5d5
Change-Id: I61b669aab230948f551e070634c117cb058924f9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/116853
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index e49be48..c311597 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -1154,7 +1154,7 @@
Candidate ScoreOverload(const OverloadInfo* overload,
utils::VectorRef<const type::Type*> args,
sem::EvaluationStage earliest_eval_stage,
- TemplateState templates) const;
+ const TemplateState& templates) const;
/// Performs overload resolution given the list of candidates, by ranking the conversions of
/// arguments to the each of the candidate's parameter types.
@@ -1560,7 +1560,7 @@
Impl::Candidate Impl::ScoreOverload(const OverloadInfo* overload,
utils::VectorRef<const type::Type*> args,
sem::EvaluationStage earliest_eval_stage,
- TemplateState templates) const {
+ const TemplateState& in_templates) const {
// Penalty weights for overload mismatching.
// This scoring is used to order the suggested overloads in diagnostic on overload mismatch, and
// has no impact for a correct program.
@@ -1580,6 +1580,10 @@
std::min(num_parameters, num_arguments));
}
+ // Make a mutable copy of the input templates so we can implicitly match more templated
+ // arguments.
+ TemplateState templates(in_templates);
+
// Invoke the matchers for each parameter <-> argument pair.
// If any arguments cannot be matched, then `score` will be increased.
// If the overload has any template types or numbers then these will be set based on the
diff --git a/src/tint/transform/localize_struct_array_assignment.cc b/src/tint/transform/localize_struct_array_assignment.cc
index 8aaa276..3bf1a41 100644
--- a/src/tint/transform/localize_struct_array_assignment.cc
+++ b/src/tint/transform/localize_struct_array_assignment.cc
@@ -47,43 +47,55 @@
utils::Vector<const ast::Statement*, 4> insert_after_stmts;
} s;
- ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt) -> const ast::Statement* {
- // Process if it's an assignment statement to a dynamically indexed array
- // within a struct on a function or private storage variable. This
- // specific use-case is what FXC fails to compile with:
- // error X3500: array reference cannot be used as an l-value; not natively
- // addressable
- if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
- return nullptr;
+ bool made_changes = false;
+
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* assign_stmt = node->As<ast::AssignmentStatement>()) {
+ // Process if it's an assignment statement to a dynamically indexed array
+ // within a struct on a function or private storage variable. This
+ // specific use-case is what FXC fails to compile with:
+ // error X3500: array reference cannot be used as an l-value; not natively
+ // addressable
+ if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
+ continue;
+ }
+ auto og = GetOriginatingTypeAndAddressSpace(assign_stmt);
+ if (!(og.first->Is<sem::Struct>() && (og.second == ast::AddressSpace::kFunction ||
+ og.second == ast::AddressSpace::kPrivate))) {
+ continue;
+ }
+
+ ctx.Replace(assign_stmt, [&, assign_stmt] {
+ // Reset shared state for this assignment statement
+ s = Shared{};
+
+ const ast::Expression* new_lhs = nullptr;
+ {
+ TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
+ new_lhs = ctx.Clone(assign_stmt->lhs);
+ }
+
+ auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
+
+ // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
+ // a block and return it
+ auto stmts = std::move(s.insert_before_stmts);
+ stmts.Reserve(1 + s.insert_after_stmts.Length());
+ stmts.Push(new_assign_stmt);
+ for (auto* stmt : s.insert_after_stmts) {
+ stmts.Push(stmt);
+ }
+
+ return b.Block(std::move(stmts));
+ });
+
+ made_changes = true;
}
- auto og = GetOriginatingTypeAndAddressSpace(assign_stmt);
- if (!(og.first->Is<sem::Struct>() && (og.second == ast::AddressSpace::kFunction ||
- og.second == ast::AddressSpace::kPrivate))) {
- return nullptr;
- }
+ }
- // Reset shared state for this assignment statement
- s = Shared{};
-
- const ast::Expression* new_lhs = nullptr;
- {
- TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
- new_lhs = ctx.Clone(assign_stmt->lhs);
- }
-
- auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
-
- // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
- // a block and return it
- auto stmts = std::move(s.insert_before_stmts);
- stmts.Reserve(1 + s.insert_after_stmts.Length());
- stmts.Push(new_assign_stmt);
- for (auto* stmt : s.insert_after_stmts) {
- stmts.Push(stmt);
- }
-
- return b.Block(std::move(stmts));
- });
+ if (!made_changes) {
+ return SkipTransform;
+ }
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* index_access) -> const ast::Expression* {
diff --git a/src/tint/transform/localize_struct_array_assignment_test.cc b/src/tint/transform/localize_struct_array_assignment_test.cc
index e85a600..9fabd95 100644
--- a/src/tint/transform/localize_struct_array_assignment_test.cc
+++ b/src/tint/transform/localize_struct_array_assignment_test.cc
@@ -25,9 +25,7 @@
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
auto* src = R"()";
- auto* expect = src;
- auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_FALSE(ShouldRun<LocalizeStructArrayAssignment>(src));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArray) {
@@ -842,10 +840,7 @@
// Transform does nothing here as we're not actually assigning to the array in
// the struct.
- auto* expect = src;
-
- auto got = Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
- EXPECT_EQ(expect, str(got));
+ EXPECT_FALSE(ShouldRun<LocalizeStructArrayAssignment>(src));
}
} // namespace
diff --git a/src/tint/utils/hashmap_base.h b/src/tint/utils/hashmap_base.h
index cd4ef6a..959bebe 100644
--- a/src/tint/utils/hashmap_base.h
+++ b/src/tint/utils/hashmap_base.h
@@ -291,24 +291,26 @@
// Shuffle the entries backwards until we either find a free slot, or a slot that has zero
// distance.
Slot* prev = nullptr;
- Scan(start, [&](size_t, size_t index) {
+
+ const auto count = slots_.Length();
+ for (size_t distance = 0, index = start; distance < count; distance++) {
auto& slot = slots_[index];
if (prev) {
// note: `distance == 0` also includes empty slots.
if (slot.distance == 0) {
// Clear the previous slot, and stop shuffling.
*prev = {};
- return Action::kStop;
- } else {
- // Shuffle the slot backwards.
- prev->entry = std::move(slot.entry);
- prev->hash = slot.hash;
- prev->distance = slot.distance - 1;
+ break;
}
+ // Shuffle the slot backwards.
+ prev->entry = std::move(slot.entry);
+ prev->hash = slot.hash;
+ prev->distance = slot.distance - 1;
}
prev = &slot;
- return Action::kContinue;
- });
+
+ index = (index == count - 1) ? 0 : index + 1;
+ }
// Entry was removed.
count_--;
@@ -438,8 +440,8 @@
}
};
- PutResult result{};
- Scan(hash.scan_start, [&](size_t distance, size_t index) {
+ const auto count = slots_.Length();
+ for (size_t distance = 0, index = hash.scan_start; distance < count; distance++) {
auto& slot = slots_[index];
if (!slot.entry.has_value()) {
// Found an empty slot.
@@ -449,8 +451,7 @@
slot.distance = distance;
count_++;
generation_++;
- result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
- return Action::kStop;
+ return PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
}
// Slot has an entry
@@ -460,11 +461,10 @@
if constexpr (MODE == PutMode::kReplace) {
slot.entry = make_entry();
generation_++;
- result = PutResult{MapAction::kReplaced, ValueOf(*slot.entry)};
+ return PutResult{MapAction::kReplaced, ValueOf(*slot.entry)};
} else {
- result = PutResult{MapAction::kKeptExisting, ValueOf(*slot.entry)};
+ return PutResult{MapAction::kKeptExisting, ValueOf(*slot.entry)};
}
- return Action::kStop;
}
if (slot.distance < distance) {
@@ -480,47 +480,15 @@
count_++;
generation_++;
- result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
-
- return Action::kStop;
+ return PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
}
- return Action::kContinue;
- });
- return result;
- }
-
- /// Return type of the Scan() callback.
- enum class Action {
- /// Continue scanning for a slot
- kContinue,
- /// Immediately stop scanning for a slot
- kStop,
- };
-
- /// Sequentially visits each of the slots starting with the slot with the index @p start,
- /// calling the callback function @p f for each slot until @p f returns Action::kStop.
- /// @param start the index of the first slot to start scanning from.
- /// @param f the callback function which:
- /// * must be a function with the signature `Action(size_t distance, size_t index)`.
- /// * must return Action::kStop within one whole cycle of the slots.
- template <typename F>
- void Scan(size_t start, F&& f) const {
- size_t distance = 0;
- for (size_t index = start; index < slots_.Length(); index++) {
- if (f(distance, index) == Action::kStop) {
- return;
- }
- distance++;
+ index = (index == count - 1) ? 0 : index + 1;
}
- for (size_t index = 0; index < start; index++) {
- if (f(distance, index) == Action::kStop) {
- return;
- }
- distance++;
- }
+
tint::diag::List diags;
- TINT_ICE(Utils, diags) << "HashmapBase::Scan() looped entire map without finding a slot";
+ TINT_ICE(Utils, diags) << "HashmapBase::Put() looped entire map without finding a slot";
+ return PutResult{};
}
/// HashResult is the return value of Hash()
@@ -546,45 +514,44 @@
/// if found, the index of the slot that holds the key.
std::tuple<bool, size_t> IndexOf(const Key& key) const {
const auto hash = Hash(key);
-
- bool found = false;
- size_t idx = 0;
-
- Scan(hash.scan_start, [&](size_t distance, size_t index) {
+ const auto count = slots_.Length();
+ for (size_t distance = 0, index = hash.scan_start; distance < count; distance++) {
auto& slot = slots_[index];
if (!slot.entry.has_value()) {
- return Action::kStop;
+ return {/* found */ false, /* index */ 0};
}
if (slot.Equals(hash.code, key)) {
- found = true;
- idx = index;
- return Action::kStop;
+ return {/* found */ true, index};
}
if (slot.distance < distance) {
- // If the slot distance is less than the current probe distance, then the slot must
- // be for entry that has an index that comes after key. In this situation, we know
- // that the map does not contain the key, as it would have been found before this
- // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html
- // suggests that the condition should inverted, but this is wrong.
- return Action::kStop;
+ // If the slot distance is less than the current probe distance, then the slot
+ // must be for entry that has an index that comes after key. In this situation,
+ // we know that the map does not contain the key, as it would have been found
+ // before this slot. The "Lookup" section of
+ // https://programming.guide/robin-hood-hashing.html suggests that the condition
+ // should inverted, but this is wrong.
+ return {/* found */ false, /* index */ 0};
}
- return Action::kContinue;
- });
+ index = (index == count - 1) ? 0 : index + 1;
+ }
- return {found, idx};
+ tint::diag::List diags;
+ TINT_ICE(Utils, diags) << "HashmapBase::IndexOf() looped entire map without finding a slot";
+ return {/* found */ false, /* index */ 0};
}
/// Shuffles slots for an insertion that has been placed one slot before `start`.
/// @param start the index of the first slot to start shuffling.
/// @param evicted the slot content that was evicted for the insertion.
void InsertShuffle(size_t start, Slot&& evicted) {
- Scan(start, [&](size_t, size_t index) {
+ const auto count = slots_.Length();
+ for (size_t distance = 0, index = start; distance < count; distance++) {
auto& slot = slots_[index];
if (!slot.entry.has_value()) {
// Empty slot found for evicted.
slot = std::move(evicted);
- return Action::kStop; // We're done.
+ return; // We're done.
}
if (slot.distance < evicted.distance) {
@@ -596,8 +563,8 @@
// evicted moves further from the target slot...
evicted.distance++;
- return Action::kContinue;
- });
+ index = (index == count - 1) ? 0 : index + 1;
+ }
}
/// @param count the number of new entries in the map
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index 83e1c12..cca2409 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -18,7 +18,6 @@
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
-#include <array>
#include <iterator>
#include <ostream>
#include <utility>
@@ -592,7 +591,7 @@
/// The internal structure for the vector with a small array.
struct ImplWithSmallArray {
- std::array<TStorage, N> small_arr;
+ TStorage small_arr[N];
Slice slice = {small_arr[0].Get(), 0, N};
/// Allocates a new vector of `T` either from #small_arr, or from the heap, then assigns the
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 8e8f1ed..fe10b2b 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -959,6 +959,9 @@
if (type == sem::BuiltinType::kRadians) {
return EmitRadiansCall(out, expr, builtin);
}
+ if (type == sem::BuiltinType::kSign) {
+ return EmitSignCall(out, call, builtin);
+ }
if (type == sem::BuiltinType::kQuantizeToF16) {
return EmitQuantizeToF16Call(out, expr, builtin);
}
@@ -2065,6 +2068,22 @@
});
}
+// The HLSL `sign` method always returns an `int` result (scalar or vector). In WGSL the result is
+// expected to be the same type as the argument. This injects a cast to the expected WGSL result
+// type after the call to `sign`.
+bool GeneratorImpl::EmitSignCall(std::ostream& out, const sem::Call* call, const sem::Builtin*) {
+ auto* arg = call->Arguments()[0];
+ if (!EmitType(out, arg->Type(), ast::AddressSpace::kNone, ast::Access::kReadWrite, "")) {
+ return false;
+ }
+ out << "(sign(";
+ if (!EmitExpression(out, arg->Declaration())) {
+ return false;
+ }
+ out << "))";
+ return true;
+}
+
bool GeneratorImpl::EmitQuantizeToF16Call(std::ostream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin) {
@@ -2662,7 +2681,6 @@
case sem::BuiltinType::kRefract:
case sem::BuiltinType::kRound:
case sem::BuiltinType::kSaturate:
- case sem::BuiltinType::kSign:
case sem::BuiltinType::kSin:
case sem::BuiltinType::kSinh:
case sem::BuiltinType::kSqrt:
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index 1558fa5..76b9042 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -244,6 +244,12 @@
bool EmitRadiansCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin);
+ /// Handles generating a call to the `sign()` builtin
+ /// @param out the output of the expression stream
+ /// @param call the call semantic node
+ /// @param builtin the semantic information for the builtin
+ /// @returns true if the call expression is emitted
+ bool EmitSignCall(std::ostream& out, const sem::Call* call, const sem::Builtin* builtin);
/// Handles generating a call to data packing builtin
/// @param out the output of the expression stream
/// @param expr the call expression
diff --git a/src/tint/writer/hlsl/generator_impl_builtin_test.cc b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
index 8b5fe07..785e1fd 100644
--- a/src/tint/writer/hlsl/generator_impl_builtin_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_builtin_test.cc
@@ -98,7 +98,6 @@
case BuiltinType::kTan:
case BuiltinType::kTanh:
case BuiltinType::kTrunc:
- case BuiltinType::kSign:
if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else {
@@ -294,8 +293,6 @@
BuiltinData{BuiltinType::kPow, CallParamType::kF16, "pow"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF32, "reflect"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF16, "reflect"},
- BuiltinData{BuiltinType::kSign, CallParamType::kF32, "sign"},
- BuiltinData{BuiltinType::kSign, CallParamType::kF16, "sign"},
BuiltinData{BuiltinType::kSin, CallParamType::kF32, "sin"},
BuiltinData{BuiltinType::kSin, CallParamType::kF16, "sin"},
BuiltinData{BuiltinType::kSinh, CallParamType::kF32, "sinh"},
@@ -1002,6 +999,112 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Scalar_i32) {
+ auto* val = Var("val", ty.i32());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ int val = 0;
+ const int tint_symbol = int(sign(val));
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Vector_i32) {
+ auto* val = Var("val", ty.vec3<i32>());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ int3 val = int3(0, 0, 0);
+ const int3 tint_symbol = int3(sign(val));
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Scalar_f32) {
+ auto* val = Var("val", ty.f32());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ float val = 0.0f;
+ const float tint_symbol = float(sign(val));
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Vector_f32) {
+ auto* val = Var("val", ty.vec3<f32>());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ float3 val = float3(0.0f, 0.0f, 0.0f);
+ const float3 tint_symbol = float3(sign(val));
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Scalar_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* val = Var("val", ty.f16());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ float16_t val = float16_t(0.0h);
+ const float16_t tint_symbol = float16_t(sign(val));
+ return;
+}
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Builtin, Sign_Vector_f16) {
+ Enable(ast::Extension::kF16);
+
+ auto* val = Var("val", ty.vec3<f16>());
+ auto* call = Call("sign", val);
+ WrapInFunction(val, call);
+
+ GeneratorImpl& gen = SanitizeAndBuild();
+
+ ASSERT_TRUE(gen.Generate()) << gen.error();
+ EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
+void test_function() {
+ vector<float16_t, 3> val = vector<float16_t, 3>(float16_t(0.0h), float16_t(0.0h), float16_t(0.0h));
+ const vector<float16_t, 3> tint_symbol = vector<float16_t, 3>(sign(val));
+ return;
+}
+)");
+}
+
TEST_F(HlslGeneratorImplTest_Builtin, Pack4x8Snorm) {
auto* call = Call("pack4x8snorm", "p1");
GlobalVar("p1", ty.vec4<f32>(), ast::AddressSpace::kPrivate);
diff --git a/src/tint/writer/hlsl/generator_impl_import_test.cc b/src/tint/writer/hlsl/generator_impl_import_test.cc
index 96ed7ee..3e9e591 100644
--- a/src/tint/writer/hlsl/generator_impl_import_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_import_test.cc
@@ -62,7 +62,6 @@
HlslImportData{"log", "log"},
HlslImportData{"log2", "log2"},
HlslImportData{"round", "round"},
- HlslImportData{"sign", "sign"},
HlslImportData{"sin", "sin"},
HlslImportData{"sinh", "sinh"},
HlslImportData{"sqrt", "sqrt"},
@@ -121,7 +120,6 @@
HlslImportData{"log2", "log2"},
HlslImportData{"normalize", "normalize"},
HlslImportData{"round", "round"},
- HlslImportData{"sign", "sign"},
HlslImportData{"sin", "sin"},
HlslImportData{"sinh", "sinh"},
HlslImportData{"sqrt", "sqrt"},