blob: 179710256890b886ba545e397d7c41905d23f287 [file] [log] [blame]
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_STREAM_STREAM_H_
#define SRC_DAWN_NATIVE_STREAM_STREAM_H_
#include <algorithm>
#include <bitset>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dawn/common/Platform.h"
#include "dawn/common/TypedInteger.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/Error.h"
namespace dawn::native {
class CacheKey;
// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing.
template <typename T>
class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
public:
constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
};
// Specialized overload for fundamental types.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
public:
static void Serialize(CacheKey* key, const T t) {
const char* it = reinterpret_cast<const char*>(&t);
key->insert(key->end(), it, (it + sizeof(T)));
}
};
// Specialized overload for bitsets that are smaller than 64.
template <size_t N>
class CacheKeySerializer<std::bitset<N>, std::enable_if_t<(N <= 64)>> {
public:
static void Serialize(CacheKey* key, const std::bitset<N>& t) { key->Record(t.to_ullong()); }
};
// Specialized overload for bitsets since using the built-in to_ullong have a size limit.
template <size_t N>
class CacheKeySerializer<std::bitset<N>, std::enable_if_t<(N > 64)>> {
public:
static void Serialize(CacheKey* key, const std::bitset<N>& t) {
// Serializes the bitset into series of uint8_t, along with recording the size.
static_assert(N > 0);
key->Record(static_cast<size_t>(N));
uint8_t value = 0;
for (size_t i = 0; i < N; i++) {
value <<= 1;
// Explicitly convert to numeric since MSVC doesn't like mixing of bools.
value |= t[i] ? 1 : 0;
if (i % 8 == 7) {
// Whenever we fill an 8 bit value, record it and zero it out.
key->Record(value);
value = 0;
}
}
// Serialize the last value if we are not a multiple of 8.
if (N % 8 != 0) {
key->Record(value);
}
}
};
// Specialized overload for enums.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_enum_v<T>>> {
public:
static void Serialize(CacheKey* key, const T t) {
CacheKeySerializer<std::underlying_type_t<T>>::Serialize(
key, static_cast<std::underlying_type_t<T>>(t));
}
};
// Specialized overload for TypedInteger.
template <typename Tag, typename Integer>
class CacheKeySerializer<::detail::TypedIntegerImpl<Tag, Integer>> {
public:
static void Serialize(CacheKey* key, const ::detail::TypedIntegerImpl<Tag, Integer> t) {
CacheKeySerializer<Integer>::Serialize(key, static_cast<Integer>(t));
}
};
// Specialized overload for pointers. Since we are serializing for a cache key, we always
// serialize via value, not by pointer. To handle nullptr scenarios, we always serialize whether
// the pointer was nullptr followed by the contents if applicable.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_pointer_v<T>>> {
public:
static void Serialize(CacheKey* key, const T t) {
key->Record(t == nullptr);
if (t != nullptr) {
CacheKeySerializer<std::remove_cv_t<std::remove_pointer_t<T>>>::Serialize(key, *t);
}
}
};
// Specialized overload for fixed arrays of primitives.
template <typename T, size_t N>
class CacheKeySerializer<T[N], std::enable_if_t<std::is_fundamental_v<T>>> {
public:
static void Serialize(CacheKey* key, const T (&t)[N]) {
static_assert(N > 0);
key->Record(static_cast<size_t>(N));
const char* it = reinterpret_cast<const char*>(t);
key->insert(key->end(), it, it + sizeof(t));
}
};
// Specialized overload for fixed arrays of non-primitives.
template <typename T, size_t N>
class CacheKeySerializer<T[N], std::enable_if_t<!std::is_fundamental_v<T>>> {
public:
static void Serialize(CacheKey* key, const T (&t)[N]) {
static_assert(N > 0);
key->Record(static_cast<size_t>(N));
for (size_t i = 0; i < N; i++) {
key->Record(t[i]);
}
}
};
// Specialized overload for CachedObjects.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>> {
public:
static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
};
// Specialized overload for std::vector.
template <typename T>
class CacheKeySerializer<std::vector<T>> {
public:
static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
};
// Specialized overload for std::pair<A, B>
template <typename A, typename B>
class CacheKeySerializer<std::pair<A, B>> {
public:
static void Serialize(CacheKey* key, const std::pair<A, B>& p) {
key->Record(p.first, p.second);
}
};
// Specialized overload for std::unordered_map<K, V>
template <typename K, typename V>
class CacheKeySerializer<std::unordered_map<K, V>> {
public:
static void Serialize(CacheKey* key, const std::unordered_map<K, V>& m) {
std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
std::sort(ordered.begin(), ordered.end(),
[](const std::pair<K, V>& a, const std::pair<K, V>& b) {
return std::less<K>{}(a.first, b.first);
});
key->RecordIterable(ordered);
}
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_STREAM_STREAM_H_