Fixes cache key generation to handle binary data.

Bug: dawn:549
Change-Id: Ie6b3ceb610b362adfed96a0982d7541002660809
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/84920
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/src/dawn/native/CacheKey.cpp b/src/dawn/native/CacheKey.cpp
index ff2cec0..3495577 100644
--- a/src/dawn/native/CacheKey.cpp
+++ b/src/dawn/native/CacheKey.cpp
@@ -18,15 +18,14 @@
 
     template <>
     void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
-        std::string len = std::to_string(t.length());
-        key->insert(key->end(), len.begin(), len.end());
-        key->push_back('"');
+        key->Record(static_cast<size_t>(t.length()));
         key->insert(key->end(), t.begin(), t.end());
-        key->push_back('"');
     }
 
     template <>
     void CacheKeySerializer<CacheKey>::Serialize(CacheKey* key, const CacheKey& t) {
+        // For nested cache keys, we do not record the length, and just copy the key so that it
+        // appears we just flatten the keys into a single key.
         key->insert(key->end(), t.begin(), t.end());
     }
 
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index cb58711..ce21f6d 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -15,79 +15,84 @@
 #ifndef DAWNNATIVE_CACHE_KEY_H_
 #define DAWNNATIVE_CACHE_KEY_H_
 
+#include <limits>
 #include <string>
+#include <type_traits>
 #include <vector>
 
-#include "dawn/common/Compiler.h"
+#include "dawn/common/Assert.h"
 
 namespace dawn::native {
 
-    using CacheKey = std::vector<uint8_t>;
+    // Forward declare CacheKey class because of co-dependency.
+    class CacheKey;
 
     // Overridable serializer struct that should be implemented for cache key serializable
     // types/classes.
     template <typename T, typename SFINAE = void>
-    struct CacheKeySerializer {
+    class CacheKeySerializer {
+      public:
         static void Serialize(CacheKey* key, const T& t);
     };
 
-    // Specialized overload for integral types. Note that we are currently serializing as a string
-    // to avoid handling null termiantors.
-    template <typename Integer>
-    struct CacheKeySerializer<Integer, std::enable_if_t<std::is_integral_v<Integer>>> {
-        static void Serialize(CacheKey* key, const Integer i) {
-            std::string str = std::to_string(i);
-            key->insert(key->end(), str.begin(), str.end());
+    class CacheKey : public std::vector<uint8_t> {
+      public:
+        using std::vector<uint8_t>::vector;
+
+        template <typename T>
+        CacheKey& Record(const T& t) {
+            CacheKeySerializer<T>::Serialize(this, t);
+            return *this;
+        }
+        template <typename T, typename... Args>
+        CacheKey& Record(const T& t, const Args&... args) {
+            CacheKeySerializer<T>::Serialize(this, t);
+            return Record(args...);
+        }
+
+        // Records iterables by prepending the number of elements. Some common iterables are have a
+        // CacheKeySerializer implemented to avoid needing to split them out when recording, i.e.
+        // strings and CacheKeys, but they fundamentally do the same as this function.
+        template <typename IterableT>
+        CacheKey& RecordIterable(const IterableT& iterable) {
+            // Always record the size of generic iterables as a size_t for now.
+            Record(static_cast<size_t>(iterable.size()));
+            for (auto it = iterable.begin(); it != iterable.end(); ++it) {
+                Record(*it);
+            }
+            return *this;
+        }
+        template <typename Ptr>
+        CacheKey& RecordIterable(const Ptr* ptr, size_t n) {
+            Record(n);
+            for (size_t i = 0; i < n; ++i) {
+                Record(ptr[i]);
+            }
+            return *this;
         }
     };
 
-    // Specialized overload for floating point types. Note that we are currently serializing as a
-    // string to avoid handling null termiantors.
-    template <typename Float>
-    struct CacheKeySerializer<Float, std::enable_if_t<std::is_floating_point_v<Float>>> {
-        static void Serialize(CacheKey* key, const Float f) {
-            std::string str = std::to_string(f);
-            key->insert(key->end(), str.begin(), str.end());
+    // Specialized overload for fundamental types.
+    template <typename T>
+    class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
+      public:
+        static void Serialize(CacheKey* key, const T t) {
+            const char* it = reinterpret_cast<const char*>(&t);
+            key->insert(key->end(), it, (it + sizeof(T)));
         }
     };
 
     // Specialized overload for string literals. Note we drop the null-terminator.
     template <size_t N>
-    struct CacheKeySerializer<char[N]> {
+    class CacheKeySerializer<char[N]> {
+      public:
         static void Serialize(CacheKey* key, const char (&t)[N]) {
-            std::string len = std::to_string(N - 1);
-            key->insert(key->end(), len.begin(), len.end());
-            key->push_back('"');
-            key->insert(key->end(), t, t + N - 1);
-            key->push_back('"');
+            static_assert(N > 0);
+            key->Record(static_cast<size_t>(N));
+            key->insert(key->end(), t, t + N);
         }
     };
 
-    // Helper template function that defers to underlying static functions.
-    template <typename T>
-    void SerializeInto(CacheKey* key, const T& t) {
-        CacheKeySerializer<T>::Serialize(key, t);
-    }
-
-    // Given list of arguments of types with a free implementation of SerializeIntoImpl in the
-    // dawn::native namespace, serializes each argument and appends them to the CacheKey while
-    // prepending member ids before each argument.
-    template <typename... Ts>
-    CacheKey GetCacheKey(const Ts&... inputs) {
-        CacheKey key;
-        key.push_back('{');
-        int memberId = 0;
-        auto Serialize = [&](const auto& input) {
-            std::string memberIdStr = (memberId == 0 ? "" : ",") + std::to_string(memberId) + ":";
-            key.insert(key.end(), memberIdStr.begin(), memberIdStr.end());
-            SerializeInto(&key, input);
-            memberId++;
-        };
-        (Serialize(inputs), ...);
-        key.push_back('}');
-        return key;
-    }
-
 }  // namespace dawn::native
 
 #endif  // DAWNNATIVE_CACHE_KEY_H_
diff --git a/src/dawn/native/CachedObject.cpp b/src/dawn/native/CachedObject.cpp
index 538b7b5c..e7e7cd8 100644
--- a/src/dawn/native/CachedObject.cpp
+++ b/src/dawn/native/CachedObject.cpp
@@ -42,26 +42,12 @@
         mIsContentHashInitialized = true;
     }
 
-    const std::string& CachedObject::GetCacheKey() const {
-        ASSERT(mIsCacheKeyBaseInitialized);
-        return mCacheKeyBase;
+    const CacheKey& CachedObject::GetCacheKey() const {
+        return mCacheKey;
     }
 
-    std::string CachedObject::GetCacheKey(DeviceBase* device) const {
-        ASSERT(mIsCacheKeyBaseInitialized);
-        // TODO(dawn:549) Prepend/append with device/adapter information.
-        return mCacheKeyBase;
-    }
-
-    void CachedObject::SetCacheKey(const std::string& cacheKey) {
-        ASSERT(!mIsContentHashInitialized);
-        mCacheKeyBase = cacheKey;
-        mIsCacheKeyBaseInitialized = true;
-    }
-
-    std::string CachedObject::ComputeCacheKeyBase() const {
-        // This implementation should never be called. Only overrides should be called.
-        UNREACHABLE();
+    CacheKey* CachedObject::GetCacheKey() {
+        return &mCacheKey;
     }
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/CachedObject.h b/src/dawn/native/CachedObject.h
index 3cf79a2..7d28ae8 100644
--- a/src/dawn/native/CachedObject.h
+++ b/src/dawn/native/CachedObject.h
@@ -15,6 +15,7 @@
 #ifndef DAWNNATIVE_CACHED_OBJECT_H_
 #define DAWNNATIVE_CACHED_OBJECT_H_
 
+#include "dawn/native/CacheKey.h"
 #include "dawn/native/Forward.h"
 
 #include <cstddef>
@@ -38,13 +39,12 @@
         size_t GetContentHash() const;
         void SetContentHash(size_t contentHash);
 
-        // Two versions of GetCacheKey, when passed a device, prepends the stored cache
-        // key base with device and adapter information. When called without passing a
-        // device, returns the stored cache key base. This is useful when the instance
-        // is a member to a parent class.
-        const std::string& GetCacheKey() const;
-        std::string GetCacheKey(DeviceBase* device) const;
-        void SetCacheKey(const std::string& cacheKey);
+        // Returns the cache key for the object only, i.e. without device/adapter information.
+        const CacheKey& GetCacheKey() const;
+
+      protected:
+        // Protected accessor for derived classes to access and modify the key.
+        CacheKey* GetCacheKey();
 
       private:
         friend class DeviceBase;
@@ -55,13 +55,9 @@
         // Called by ObjectContentHasher upon creation to record the object.
         virtual size_t ComputeContentHash() = 0;
 
-        // Not all classes implement cache key computation, so by default we assert.
-        virtual std::string ComputeCacheKeyBase() const;
-
         size_t mContentHash = 0;
         bool mIsContentHashInitialized = false;
-        std::string mCacheKeyBase = "";
-        bool mIsCacheKeyBaseInitialized = false;
+        CacheKey mCacheKey;
     };
 
 }  // namespace dawn::native
diff --git a/src/dawn/tests/unittests/native/CacheKeyTests.cpp b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
index 3228e1e..45fd360 100644
--- a/src/dawn/tests/unittests/native/CacheKeyTests.cpp
+++ b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
@@ -15,73 +15,168 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#include <cstring>
+#include <iomanip>
 #include <string>
 
 #include "dawn/native/CacheKey.h"
 
 namespace dawn::native {
 
-    // Testing classes/structs with serializing implemented for testing.
-    struct A {};
+    // Testing classes with mock serializing implemented for testing.
+    class A {
+      public:
+        MOCK_METHOD(void, SerializeMock, (CacheKey*, const A&), (const));
+    };
     template <>
     void CacheKeySerializer<A>::Serialize(CacheKey* key, const A& t) {
-        std::string str = "structA";
-        key->insert(key->end(), str.begin(), str.end());
+        t.SerializeMock(key, t);
     }
 
-    class B {};
-    template <>
-    void CacheKeySerializer<B>::Serialize(CacheKey* key, const B& t) {
-        std::string str = "classB";
-        key->insert(key->end(), str.begin(), str.end());
+    // Custom printer for CacheKey for clearer debug testing messages.
+    void PrintTo(const CacheKey& key, std::ostream* stream) {
+        *stream << std::hex;
+        for (const int b : key) {
+            *stream << std::setfill('0') << std::setw(2) << b << " ";
+        }
+        *stream << std::dec;
     }
 
     namespace {
 
-        // Matcher to compare CacheKey to a string for easier testing.
-        MATCHER_P(CacheKeyEq,
-                  key,
-                  "cache key " + std::string(negation ? "not" : "") + "equal to " + key) {
-            return std::string(arg.begin(), arg.end()) == key;
+        using ::testing::InSequence;
+        using ::testing::NotNull;
+        using ::testing::PrintToString;
+        using ::testing::Ref;
+
+        // Matcher to compare CacheKeys for easier testing.
+        MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
+            return memcmp(arg.data(), key.data(), arg.size()) == 0;
         }
 
-        TEST(CacheKeyTest, IntegralTypes) {
-            EXPECT_THAT(GetCacheKey((int)-1), CacheKeyEq("{0:-1}"));
-            EXPECT_THAT(GetCacheKey((uint8_t)2), CacheKeyEq("{0:2}"));
-            EXPECT_THAT(GetCacheKey((uint16_t)4), CacheKeyEq("{0:4}"));
-            EXPECT_THAT(GetCacheKey((uint32_t)8), CacheKeyEq("{0:8}"));
-            EXPECT_THAT(GetCacheKey((uint64_t)16), CacheKeyEq("{0:16}"));
+        TEST(CacheKeyTests, RecordSingleMember) {
+            CacheKey key;
 
-            EXPECT_THAT(GetCacheKey((int)-1, (uint8_t)2, (uint16_t)4, (uint32_t)8, (uint64_t)16),
-                        CacheKeyEq("{0:-1,1:2,2:4,3:8,4:16}"));
+            A a;
+            EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+            EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey()));
         }
 
-        TEST(CacheKeyTest, FloatingTypes) {
-            EXPECT_THAT(GetCacheKey((float)0.5), CacheKeyEq("{0:0.500000}"));
-            EXPECT_THAT(GetCacheKey((double)32.0), CacheKeyEq("{0:32.000000}"));
+        TEST(CacheKeyTests, RecordManyMembers) {
+            constexpr size_t kNumMembers = 100;
 
-            EXPECT_THAT(GetCacheKey((float)0.5, (double)32.0),
-                        CacheKeyEq("{0:0.500000,1:32.000000}"));
+            CacheKey key;
+            for (size_t i = 0; i < kNumMembers; ++i) {
+                A a;
+                EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                key.Record(a);
+            }
+            EXPECT_THAT(key, CacheKeyEq(CacheKey()));
         }
 
-        TEST(CacheKeyTest, Strings) {
-            std::string str0 = "string0";
-            std::string str1 = "string1";
+        TEST(CacheKeyTests, RecordIterable) {
+            constexpr size_t kIterableSize = 100;
 
-            EXPECT_THAT(GetCacheKey("string0"), CacheKeyEq(R"({0:7"string0"})"));
-            EXPECT_THAT(GetCacheKey(str0), CacheKeyEq(R"({0:7"string0"})"));
-            EXPECT_THAT(GetCacheKey("string0", str1), CacheKeyEq(R"({0:7"string0",1:7"string1"})"));
+            // Expecting the size of the container.
+            CacheKey expected;
+            expected.Record(kIterableSize);
+
+            std::vector<A> iterable(kIterableSize);
+            {
+                InSequence seq;
+                for (const auto& a : iterable) {
+                    EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                }
+                for (const auto& a : iterable) {
+                    EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                }
+            }
+
+            EXPECT_THAT(CacheKey().RecordIterable(iterable), CacheKeyEq(expected));
+            EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize),
+                        CacheKeyEq(expected));
         }
 
-        TEST(CacheKeyTest, NestedCacheKey) {
-            EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1)), CacheKeyEq("{0:{0:-1}}"));
-            EXPECT_THAT(GetCacheKey(GetCacheKey("string")), CacheKeyEq(R"({0:{0:6"string"}})"));
-            EXPECT_THAT(GetCacheKey(GetCacheKey(A{})), CacheKeyEq("{0:{0:structA}}"));
-            EXPECT_THAT(GetCacheKey(GetCacheKey(B())), CacheKeyEq("{0:{0:classB}}"));
+        TEST(CacheKeyTests, RecordNested) {
+            CacheKey expected;
+            CacheKey actual;
+            {
+                // Recording a single member.
+                A a;
+                EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                actual.Record(CacheKey().Record(a));
+            }
+            {
+                // Recording multiple members.
+                constexpr size_t kNumMembers = 2;
+                CacheKey sub;
+                for (size_t i = 0; i < kNumMembers; ++i) {
+                    A a;
+                    EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                    sub.Record(a);
+                }
+                actual.Record(sub);
+            }
+            {
+                // Record an iterable.
+                constexpr size_t kIterableSize = 2;
+                expected.Record(kIterableSize);
+                std::vector<A> iterable(kIterableSize);
+                {
+                    InSequence seq;
+                    for (const auto& a : iterable) {
+                        EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
+                    }
+                }
+                actual.Record(CacheKey().RecordIterable(iterable));
+            }
+            EXPECT_THAT(actual, CacheKeyEq(expected));
+        }
 
-            EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1), GetCacheKey("string"), GetCacheKey(A{}),
-                                    GetCacheKey(B())),
-                        CacheKeyEq(R"({0:{0:-1},1:{0:6"string"},2:{0:structA},3:{0:classB}})"));
+        TEST(CacheKeySerializerTests, IntegralTypes) {
+            // Only testing explicitly sized types for simplicity, and using 0s for larger types to
+            // avoid dealing with endianess.
+            EXPECT_THAT(CacheKey().Record('c'), CacheKeyEq(CacheKey({'c'})));
+            EXPECT_THAT(CacheKey().Record(uint8_t(255)), CacheKeyEq(CacheKey({255})));
+            EXPECT_THAT(CacheKey().Record(uint16_t(0)), CacheKeyEq(CacheKey({0, 0})));
+            EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0})));
+        }
+
+        TEST(CacheKeySerializerTests, FloatingTypes) {
+            // Using 0s to avoid dealing with implementation specific float details.
+            EXPECT_THAT(CacheKey().Record(float(0)), CacheKeyEq(CacheKey(sizeof(float), 0)));
+            EXPECT_THAT(CacheKey().Record(double(0)), CacheKeyEq(CacheKey(sizeof(double), 0)));
+        }
+
+        TEST(CacheKeySerializerTests, LiteralStrings) {
+            // Using a std::string here to help with creating the expected result.
+            std::string str = "string";
+
+            CacheKey expected;
+            expected.Record(size_t(7));
+            expected.insert(expected.end(), str.begin(), str.end());
+            expected.push_back('\0');
+
+            EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected));
+        }
+
+        TEST(CacheKeySerializerTests, StdStrings) {
+            std::string str = "string";
+
+            CacheKey expected;
+            expected.Record((size_t)6);
+            expected.insert(expected.end(), str.begin(), str.end());
+
+            EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
+        }
+
+        TEST(CacheKeySerializerTests, CacheKeys) {
+            CacheKey data = {'d', 'a', 't', 'a'};
+
+            CacheKey expected;
+            expected.insert(expected.end(), data.begin(), data.end());
+
+            EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected));
         }
 
     }  // namespace