tint/utils: Fix Hashmap::GetOrCreate() for map mutation in create
Its not unreasonable for the create callback to mutate the map. If this
happened, the map would be corrupted.
This change fixes this.
Change-Id: I2bb3820061c741c6da36ebe3667cb6b878515a27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/100903
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
index 1154086..87a0dd9 100644
--- a/src/tint/utils/hashmap.h
+++ b/src/tint/utils/hashmap.h
@@ -35,28 +35,6 @@
typename HASH = Hasher<K>,
typename EQUAL = std::equal_to<K>>
class Hashmap {
- /// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into
- /// the underlying Hashset.
- ///
- /// LazyCreator holds a #key, and a #create function used to build the final Entry::value.
- /// The #create function must be of the signature `V()`.
- ///
- /// LazyCreator can be compared to Entry and hashed, allowing them to be passed to
- /// Hashset::Insert(). If the set does not contain an existing entry with #key,
- /// Hashset::Insert() will construct a new Entry passing the rvalue LazyCreator as the
- /// constructor argument, which in turn calls the #create function to generate the entry value.
- ///
- /// @see Entry
- /// @see Hasher
- /// @see Equality
- template <typename CREATE>
- struct LazyCreator {
- /// The key of the entry to insert into the map
- const K& key;
- /// The value creation function
- CREATE create;
- };
-
/// Entry holds a key and value pair, and is used as the element type of the underlying Hashset.
/// Entries are compared and hashed using only the #key.
/// @see Hasher
@@ -71,23 +49,6 @@
/// Move-constructor.
Entry(Entry&&) = default;
- /// Constructor from a LazyCreator.
- /// The constructor invokes the LazyCreator::create function to build the #value.
- /// @see LazyCreator
- template <typename CREATE>
- Entry(const LazyCreator<CREATE>& creator) // NOLINT(runtime/explicit)
- : key(creator.key), value(creator.create()) {}
-
- /// Assignment operator from a LazyCreator.
- /// The assignment invokes the LazyCreator::create function to build the #value.
- /// @see LazyCreator
- template <typename CREATE>
- Entry& operator=(LazyCreator<CREATE>&& creator) {
- key = std::move(creator.key);
- value = creator.create();
- return *this;
- }
-
/// Copy-assignment operator
Entry& operator=(const Entry&) = default;
@@ -99,33 +60,23 @@
};
/// Hash provider for the underlying Hashset.
- /// Provides hash functions for an Entry, K or LazyCreator.
+ /// Provides hash functions for an Entry or K.
/// The hash functions only consider the key of an entry.
struct Hasher {
/// Calculates a hash from an Entry
size_t operator()(const Entry& entry) const { return HASH()(entry.key); }
/// Calculates a hash from a K
size_t operator()(const K& key) const { return HASH()(key); }
- /// Calculates a hash from a LazyCreator
- template <typename CREATE>
- size_t operator()(const LazyCreator<CREATE>& lc) const {
- return HASH()(lc.key);
- }
};
/// Equality provider for the underlying Hashset.
- /// Provides equality functions for an Entry, K or LazyCreator to an Entry.
+ /// Provides equality functions for an Entry or K to an Entry.
/// The equality functions only consider the key for equality.
struct Equality {
/// Compares an Entry to an Entry for equality.
bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); }
/// Compares a K to an Entry for equality.
bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); }
- /// Compares a LazyCreator to an Entry for equality.
- template <typename CREATE>
- bool operator()(const LazyCreator<CREATE>& lc, const Entry& b) const {
- return EQUAL()(lc.key, b.key);
- }
};
/// The underlying set
@@ -151,7 +102,8 @@
/// Used by gmock for the `ElementsAre` checks.
using value_type = KeyValue;
- /// Iterator for the map
+ /// Iterator for the map.
+ /// Iterators are invalidated if the map is modified.
class Iterator {
public:
/// @returns the key of the entry pointed to by this iterator
@@ -226,13 +178,29 @@
/// Searches for an entry with the given key value, adding and returning the result of
/// calling `create` if the entry was not found.
+ /// @note: Before calling `create`, the map will insert a zero-initialized value for the given
+ /// key, which will be replaced with the value returned by `create`. If `create` adds an entry
+ /// with `key` to this map, it will be replaced.
/// @param key the entry's key value to search for.
/// @param create the create function to call if the map does not contain the key.
/// @returns the value of the entry.
template <typename CREATE>
V& GetOrCreate(const K& key, CREATE&& create) {
- LazyCreator<CREATE> lc{key, std::forward<CREATE>(create)};
- auto res = set_.Add(std::move(lc));
+ auto res = set_.Add(Entry{key, V{}});
+ if (res.action == AddAction::kAdded) {
+ // Store the set generation before calling create()
+ auto generation = set_.Generation();
+ // Call create(), which might modify this map.
+ auto value = create();
+ // Was this map mutated?
+ if (set_.Generation() == generation) {
+ // Calling create() did not touch the map. No need to lookup again.
+ res.entry->value = std::move(value);
+ } else {
+ // Calling create() modified the map. Need to insert again.
+ res = set_.Replace(Entry{key, std::move(value)});
+ }
+ }
return res.entry->value;
}
@@ -241,9 +209,7 @@
/// @param key the entry's key value to search for.
/// @returns the value of the entry.
V& GetOrZero(const K& key) {
- auto zero = [] { return V{}; };
- LazyCreator<decltype(zero)> lc{key, zero};
- auto res = set_.Add(std::move(lc));
+ auto res = set_.Add(Entry{key, V{}});
return res.entry->value;
}
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
index e52b144..9a5b01e 100644
--- a/src/tint/utils/hashmap_test.cc
+++ b/src/tint/utils/hashmap_test.cc
@@ -119,9 +119,16 @@
TEST(Hashmap, GetOrCreate) {
Hashmap<int, std::string, 8> map;
- EXPECT_EQ(map.GetOrCreate(0, [&] { return "zero"; }), "zero");
+ std::optional<std::string> value_of_key_0_at_create;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ value_of_key_0_at_create = map.Get(0);
+ return "zero";
+ }),
+ "zero");
EXPECT_EQ(map.Count(), 1u);
EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(value_of_key_0_at_create, "");
bool create_called = false;
EXPECT_EQ(map.GetOrCreate(0,
@@ -139,6 +146,67 @@
EXPECT_EQ(map.Get(1), "one");
}
+TEST(Hashmap, GetOrCreate_CreateModifiesMap) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ map.Add(3, "three");
+ map.Add(1, "one");
+ map.Add(2, "two");
+ return "zero";
+ }),
+ "zero");
+ EXPECT_EQ(map.Count(), 4u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+
+ bool create_called = false;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ create_called = true;
+ return "oh noes";
+ }),
+ "zero");
+ EXPECT_FALSE(create_called);
+ EXPECT_EQ(map.Count(), 4u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+
+ EXPECT_EQ(map.GetOrCreate(4,
+ [&] {
+ map.Add(6, "six");
+ map.Add(5, "five");
+ map.Add(7, "seven");
+ return "four";
+ }),
+ "four");
+ EXPECT_EQ(map.Count(), 8u);
+ EXPECT_EQ(map.Get(0), "zero");
+ EXPECT_EQ(map.Get(1), "one");
+ EXPECT_EQ(map.Get(2), "two");
+ EXPECT_EQ(map.Get(3), "three");
+ EXPECT_EQ(map.Get(4), "four");
+ EXPECT_EQ(map.Get(5), "five");
+ EXPECT_EQ(map.Get(6), "six");
+ EXPECT_EQ(map.Get(7), "seven");
+}
+
+TEST(Hashmap, GetOrCreate_CreateAddsSameKeyedValue) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.GetOrCreate(42,
+ [&] {
+ map.Add(42, "should-be-replaced");
+ return "expected-value";
+ }),
+ "expected-value");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_EQ(map.Get(42), "expected-value");
+}
+
TEST(Hashmap, Soak) {
std::mt19937 rnd;
std::unordered_map<std::string, std::string> reference;