Add Ref<T> specialization for Result

Ref<T> specialization will allow us to, in a future change, return
Result<Ref<T>> instances from Create methods while still keeping the
tagged pointer optimization.

Change-Id: I20c764358af22ba1dc53458d59b0b2b4770a0c6a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/19801
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Rafael Cintron <rafael.cintron@microsoft.com>
diff --git a/src/common/Result.h b/src/common/Result.h
index 7dc5e2d..8509568 100644
--- a/src/common/Result.h
+++ b/src/common/Result.h
@@ -166,6 +166,35 @@
     intptr_t mPayload = detail::kEmptyPayload;
 };
 
+template <typename T>
+class Ref;
+
+template <typename T, typename E>
+class DAWN_NO_DISCARD Result<Ref<T>, E> {
+  public:
+    static_assert(alignof_if_defined_else_default<T, 4> >= 4,
+                  "Result<Ref<T>, E> reserves two bits for tagging pointers");
+    static_assert(alignof_if_defined_else_default<E, 4> >= 4,
+                  "Result<Ref<T>, E> reserves two bits for tagging pointers");
+
+    Result(Ref<T>&& success);
+    Result(std::unique_ptr<E> error);
+
+    Result(Result<Ref<T>, E>&& other);
+    Result<Ref<T>, E>& operator=(Result<Ref<T>, E>&& other);
+
+    ~Result();
+
+    bool IsError() const;
+    bool IsSuccess() const;
+
+    Ref<T> AcquireSuccess();
+    std::unique_ptr<E> AcquireError();
+
+  private:
+    intptr_t mPayload = detail::kEmptyPayload;
+};
+
 // Catchall definition of Result<T, E> implemented as a tagged struct. It could be improved to use
 // a tagged union instead if it turns out to be a hotspot. T and E must be movable and default
 // constructible.
@@ -368,6 +397,61 @@
     return std::move(error);
 }
 
+// Implementation of Result<Ref<T>, E>
+template <typename T, typename E>
+Result<Ref<T>, E>::Result(Ref<T>&& success)
+    : mPayload(detail::MakePayload(success.Detach(), detail::Success)) {
+}
+
+template <typename T, typename E>
+Result<Ref<T>, E>::Result(std::unique_ptr<E> error)
+    : mPayload(detail::MakePayload(error.release(), detail::Error)) {
+}
+
+template <typename T, typename E>
+Result<Ref<T>, E>::Result(Result<Ref<T>, E>&& other) : mPayload(other.mPayload) {
+    other.mPayload = detail::kEmptyPayload;
+}
+
+template <typename T, typename E>
+Result<Ref<T>, E>& Result<Ref<T>, E>::operator=(Result<Ref<T>, E>&& other) {
+    ASSERT(mPayload == detail::kEmptyPayload);
+    mPayload = other.mPayload;
+    other.mPayload = detail::kEmptyPayload;
+    return *this;
+}
+
+template <typename T, typename E>
+Result<Ref<T>, E>::~Result() {
+    ASSERT(mPayload == detail::kEmptyPayload);
+}
+
+template <typename T, typename E>
+bool Result<Ref<T>, E>::IsError() const {
+    return detail::GetPayloadType(mPayload) == detail::Error;
+}
+
+template <typename T, typename E>
+bool Result<Ref<T>, E>::IsSuccess() const {
+    return detail::GetPayloadType(mPayload) == detail::Success;
+}
+
+template <typename T, typename E>
+Ref<T> Result<Ref<T>, E>::AcquireSuccess() {
+    ASSERT(IsSuccess());
+    Ref<T> success = AcquireRef(detail::GetSuccessFromPayload<T>(mPayload));
+    mPayload = detail::kEmptyPayload;
+    return success;
+}
+
+template <typename T, typename E>
+std::unique_ptr<E> Result<Ref<T>, E>::AcquireError() {
+    ASSERT(IsError());
+    std::unique_ptr<E> error(detail::GetErrorFromPayload<E>(mPayload));
+    mPayload = detail::kEmptyPayload;
+    return std::move(error);
+}
+
 // Implementation of Result<T, E>
 template <typename T, typename E>
 Result<T, E>::Result(T&& success) : mType(Success), mSuccess(std::move(success)) {
diff --git a/src/tests/unittests/ResultTests.cpp b/src/tests/unittests/ResultTests.cpp
index dd87e22..8e2f05f 100644
--- a/src/tests/unittests/ResultTests.cpp
+++ b/src/tests/unittests/ResultTests.cpp
@@ -14,32 +14,71 @@
 
 #include <gtest/gtest.h>
 
+#include "common/RefCounted.h"
 #include "common/Result.h"
 
 namespace {
 
 template<typename T, typename E>
 void TestError(Result<T, E>* result, E expectedError) {
-    ASSERT_TRUE(result->IsError());
-    ASSERT_FALSE(result->IsSuccess());
+    EXPECT_TRUE(result->IsError());
+    EXPECT_FALSE(result->IsSuccess());
 
     std::unique_ptr<E> storedError = result->AcquireError();
-    ASSERT_EQ(*storedError, expectedError);
+    EXPECT_EQ(*storedError, expectedError);
 }
 
 template<typename T, typename E>
 void TestSuccess(Result<T, E>* result, T expectedSuccess) {
-    ASSERT_FALSE(result->IsError());
-    ASSERT_TRUE(result->IsSuccess());
+    EXPECT_FALSE(result->IsError());
+    EXPECT_TRUE(result->IsSuccess());
 
-    T storedSuccess = result->AcquireSuccess();
-    ASSERT_EQ(storedSuccess, expectedSuccess);
+    const T storedSuccess = result->AcquireSuccess();
+    EXPECT_EQ(storedSuccess, expectedSuccess);
+
+    // Once the success is acquired, result has an empty
+    // payload and is neither in the success nor error state.
+    EXPECT_FALSE(result->IsError());
+    EXPECT_FALSE(result->IsSuccess());
 }
 
 static int dummyError = 0xbeef;
 static float dummySuccess = 42.0f;
 static const float dummyConstSuccess = 42.0f;
 
+class AClass : public RefCounted {
+  public:
+    int a = 0;
+};
+
+// Tests using the following overload of TestSuccess make
+// local Ref instances to dummySuccessObj. Tests should
+// ensure any local Ref objects made along the way continue
+// to point to dummySuccessObj.
+template <typename T, typename E>
+void TestSuccess(Result<Ref<T>, E>* result, T* expectedSuccess) {
+    EXPECT_FALSE(result->IsError());
+    EXPECT_TRUE(result->IsSuccess());
+
+    // AClass starts with a reference count of 1 and stored
+    // on the stack in the caller. The result parameter should
+    // hold the only other reference to the object.
+    EXPECT_EQ(expectedSuccess->GetRefCountForTesting(), 2u);
+
+    const Ref<T> storedSuccess = result->AcquireSuccess();
+    EXPECT_EQ(storedSuccess.Get(), expectedSuccess);
+
+    // Once the success is acquired, result has an empty
+    // payload and is neither in the success nor error state.
+    EXPECT_FALSE(result->IsError());
+    EXPECT_FALSE(result->IsSuccess());
+
+    // Once we call AcquireSuccess, result no longer stores
+    // the object. storedSuccess should contain the only other
+    // reference to the object.
+    EXPECT_EQ(storedSuccess->GetRefCountForTesting(), 2u);
+}
+
 // Result<void, E*>
 
 // Test constructing an error Result<void, E>
@@ -66,16 +105,16 @@
 // Test constructing a success Result<void, E>
 TEST(ResultOnlyPointerError, ConstructingSuccess) {
     Result<void, int> result;
-    ASSERT_TRUE(result.IsSuccess());
-    ASSERT_FALSE(result.IsError());
+    EXPECT_TRUE(result.IsSuccess());
+    EXPECT_FALSE(result.IsError());
 }
 
 // Test moving a success Result<void, E>
 TEST(ResultOnlyPointerError, MovingSuccess) {
     Result<void, int> result;
     Result<void, int> movedResult(std::move(result));
-    ASSERT_TRUE(movedResult.IsSuccess());
-    ASSERT_FALSE(movedResult.IsError());
+    EXPECT_TRUE(movedResult.IsSuccess());
+    EXPECT_FALSE(movedResult.IsError());
 }
 
 // Test returning a success Result<void, E>
@@ -83,8 +122,8 @@
     auto CreateError = []() -> Result<void, int> { return {}; };
 
     Result<void, int> result = CreateError();
-    ASSERT_TRUE(result.IsSuccess());
-    ASSERT_FALSE(result.IsError());
+    EXPECT_TRUE(result.IsSuccess());
+    EXPECT_FALSE(result.IsError());
 }
 
 // Result<T*, E*>
@@ -204,6 +243,59 @@
     TestSuccess(&result, &dummyConstSuccess);
 }
 
+// Result<Ref<T>, E>
+
+// Test constructing an error Result<Ref<T>, E>
+TEST(ResultRefT, ConstructingError) {
+    Result<Ref<AClass>, int> result(std::make_unique<int>(dummyError));
+    TestError(&result, dummyError);
+}
+
+// Test moving an error Result<Ref<T>, E>
+TEST(ResultRefT, MovingError) {
+    Result<Ref<AClass>, int> result(std::make_unique<int>(dummyError));
+    Result<Ref<AClass>, int> movedResult(std::move(result));
+    TestError(&movedResult, dummyError);
+}
+
+// Test returning an error Result<Ref<T>, E>
+TEST(ResultRefT, ReturningError) {
+    auto CreateError = []() -> Result<Ref<AClass>, int> {
+        return {std::make_unique<int>(dummyError)};
+    };
+
+    Result<Ref<AClass>, int> result = CreateError();
+    TestError(&result, dummyError);
+}
+
+// Test constructing a success Result<Ref<T>, E>
+TEST(ResultRefT, ConstructingSuccess) {
+    AClass success;
+
+    Ref<AClass> refObj(&success);
+    Result<Ref<AClass>, int> result(std::move(refObj));
+    TestSuccess(&result, &success);
+}
+
+// Test moving a success Result<Ref<T>, E>
+TEST(ResultRefT, MovingSuccess) {
+    AClass success;
+
+    Ref<AClass> refObj(&success);
+    Result<Ref<AClass>, int> result(std::move(refObj));
+    Result<Ref<AClass>, int> movedResult(std::move(result));
+    TestSuccess(&movedResult, &success);
+}
+
+// Test returning a success Result<Ref<T>, E>
+TEST(ResultRefT, ReturningSuccess) {
+    AClass success;
+    auto CreateSuccess = [&success]() -> Result<Ref<AClass>, int> { return Ref<AClass>(&success); };
+
+    Result<Ref<AClass>, int> result = CreateSuccess();
+    TestSuccess(&result, &success);
+}
+
 // Result<T, E>
 
 // Test constructing an error Result<T, E>