Skip to content

Commit 81e17b7

Browse files
committed
Make it possible for UniquePtr and SharedPtr to be constructed and host derived instances of base classes
1 parent 17f2e3f commit 81e17b7

4 files changed

Lines changed: 162 additions & 23 deletions

File tree

Tests/Runner.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,4 +1796,47 @@ TEST(ARLibTests, BigIntOrderingTests) {
17961796
EXPECT_EQ(vec[0], a);
17971797
EXPECT_EQ(vec[1], c);
17981798
EXPECT_EQ(vec[2], b);
1799+
}
1800+
TEST(ARLibTests, UniquePtrDerivedBaseTests) {
1801+
struct Base {
1802+
virtual ~Base() = default;
1803+
virtual StringView name() const { return "Base"_sv; }
1804+
};
1805+
struct Derived : public Base {
1806+
StringView name() const override { return "Derived"_sv; }
1807+
bool derived_only() const { return true; }
1808+
};
1809+
UniquePtr<Base> ptr{ new Derived };
1810+
UniquePtr<Base> ptr2 = UniquePtr<Derived>(new Derived);
1811+
EXPECT_EQ(ptr->name(), "Derived"_sv);
1812+
EXPECT_EQ(ptr2->name(), "Derived"_sv);
1813+
EXPECT_TRUE(ptr.get<Derived>()->derived_only());
1814+
EXPECT_TRUE(ptr2.get<Derived>()->derived_only());
1815+
EXPECT_TRUE(ptr.as<Derived>().derived_only());
1816+
EXPECT_TRUE(ptr2.as<Derived>().derived_only());
1817+
}
1818+
TEST(ARLibTests, SharedPtrDerivedBaseTests) {
1819+
struct Base {
1820+
virtual ~Base() = default;
1821+
virtual StringView name() const { return "Base"_sv; }
1822+
};
1823+
struct Derived : public Base {
1824+
StringView name() const override { return "Derived"_sv; }
1825+
bool derived_only() const { return true; }
1826+
};
1827+
SharedPtr<Base> empty{};
1828+
{
1829+
SharedPtr<Base> ptr{ new Derived };
1830+
SharedPtr<Base> ptr2 = SharedPtr<Derived>(new Derived);
1831+
SharedPtr<Base> ptr3 = ptr;
1832+
EXPECT_EQ(ptr->name(), "Derived"_sv);
1833+
EXPECT_EQ(ptr2->name(), "Derived"_sv);
1834+
EXPECT_EQ(ptr3->name(), "Derived"_sv);
1835+
EXPECT_TRUE(ptr.get<Derived>()->derived_only());
1836+
EXPECT_TRUE(ptr2.get<Derived>()->derived_only());
1837+
EXPECT_TRUE(ptr.as<Derived>().derived_only());
1838+
EXPECT_TRUE(ptr2.as<Derived>().derived_only());
1839+
empty = ptr3;
1840+
}
1841+
EXPECT_EQ(empty->name(), "Derived"_sv);
17991842
}

include/SharedPtr.hpp

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
namespace ARLib {
99
template <typename T>
1010
class SharedPtr {
11-
T* m_storage = nullptr;
12-
RefCountBase<T>* m_count = nullptr;
11+
T* m_storage = nullptr;
12+
RefCountBase<>* m_count = nullptr;
13+
14+
template <typename U>
15+
friend class SharedPtr;
1316
void decrease_instance_count_() {
1417
if (m_count == nullptr) return;
15-
m_count->decref();
18+
m_count->decref<T>();
1619
if (m_count->count() == 0) {
1720
delete m_count;
1821
m_count = nullptr;
1922
}
2023
}
21-
22-
SharedPtr(WeakPtr<T>& weak) {
24+
SharedPtr(WeakPtr<T>& weak) {
2325
m_storage = weak.m_storage;
2426
m_count = weak.m_count;
2527
m_count->incref();
@@ -39,24 +41,59 @@ class SharedPtr {
3941
other.m_count = nullptr;
4042
return *this;
4143
}
44+
template <DerivedFrom<T> U>
45+
SharedPtr(SharedPtr<U>&& other) noexcept : m_storage(other.m_storage), m_count(other.m_count) {
46+
other.m_storage = nullptr;
47+
other.m_count = nullptr;
48+
}
49+
template <DerivedFrom<T> U>
50+
SharedPtr& operator=(SharedPtr<U>&& other) noexcept {
51+
decrease_instance_count_();
52+
m_storage = other.m_storage;
53+
m_count = other.m_count;
54+
other.m_storage = nullptr;
55+
other.m_count = nullptr;
56+
return *this;
57+
}
4258
SharedPtr(nullptr_t) = delete;
43-
SharedPtr(T* ptr) : m_storage(ptr), m_count(new RefCountBase<T>{ m_storage }) {
59+
SharedPtr(T* ptr) : m_storage(ptr), m_count(new RefCountBase<>{ m_storage }) {
4460
HARD_ASSERT(ptr, "Pointer passed to SharedPtr must not be null");
4561
}
4662
SharedPtr(T&& storage) {
4763
m_storage = new T{ move(storage) };
48-
m_count = new RefCountBase<T>{ m_storage };
64+
m_count = new RefCountBase<>{ m_storage };
65+
}
66+
template <DerivedFrom<T> U>
67+
SharedPtr(U* ptr) : m_storage(ptr), m_count(new RefCountBase<>{ m_storage }) {
68+
HARD_ASSERT(ptr, "Pointer passed to SharedPtr must not be null");
69+
}
70+
template <DerivedFrom<T> U>
71+
SharedPtr(U&& storage) {
72+
m_storage = new U{ move(storage) };
73+
m_count = new RefCountBase<>{ m_storage };
4974
}
5075
SharedPtr(const SharedPtr& other) {
5176
m_storage = other.m_storage;
5277
m_count = other.m_count;
5378
if (m_storage == nullptr && m_count == nullptr) return;
5479
m_count->incref();
5580
}
81+
template <DerivedFrom<T> U>
82+
SharedPtr(const SharedPtr<U>& other) {
83+
m_storage = other.m_storage;
84+
m_count = other.m_count;
85+
if (m_storage == nullptr && m_count == nullptr) return;
86+
m_count->incref();
87+
}
5688
template <typename... Args>
5789
SharedPtr(EmplaceT<T>, Args&&... args) {
5890
m_storage = new T{ Forward<Args>(args)... };
59-
m_count = new RefCountBase<T>{ m_storage };
91+
m_count = new RefCountBase<>{ m_storage };
92+
}
93+
template <DerivedFrom<T> U, typename... Args>
94+
SharedPtr(EmplaceT<U>, Args&&... args) {
95+
m_storage = new U{ Forward<Args>(args)... };
96+
m_count = new RefCountBase<>{ m_storage };
6097
}
6198
SharedPtr& operator=(const SharedPtr& other) {
6299
if (this == &other) return *this;
@@ -67,10 +104,20 @@ class SharedPtr {
67104
m_count->incref();
68105
return *this;
69106
}
107+
template <DerivedFrom<T> U>
108+
SharedPtr& operator=(const SharedPtr<U>& other) {
109+
if (this == &other) return *this;
110+
reset();
111+
m_storage = other.m_storage;
112+
m_count = other.m_count;
113+
if (m_storage == nullptr || m_count == nullptr) return *this;
114+
m_count->incref();
115+
return *this;
116+
}
70117
bool operator==(const SharedPtr& other) const { return m_storage == other.m_storage; }
71118
bool operator==(const T* other_ptr) const { return m_storage == other_ptr; }
72119
T* release() {
73-
T* ptr = m_count->release_storage();
120+
T* ptr = m_count->release_storage<T>();
74121
decrease_instance_count_();
75122
m_count = nullptr;
76123
m_storage = nullptr;
@@ -89,6 +136,22 @@ class SharedPtr {
89136
WeakPtr<T> weakptr() const { return WeakPtr{ m_storage, m_count }; }
90137
T* get() { return m_storage; }
91138
const T* get() const { return m_storage; }
139+
template <DerivedFrom<T> U>
140+
U* get() {
141+
return static_cast<U*>(m_storage);
142+
}
143+
template <DerivedFrom<T> U>
144+
const U* get() const {
145+
return static_cast<const T*>(m_storage);
146+
}
147+
template <DerivedFrom<T> U>
148+
U& as() {
149+
return *get<U>();
150+
}
151+
template <DerivedFrom<T> U>
152+
const U& as() const {
153+
return *get<U>();
154+
}
92155
auto refcount() const { return m_count ? m_count->count() : 0ul; }
93156
bool exists() const { return m_storage != nullptr; }
94157
T* operator->() { return m_storage; }
@@ -97,21 +160,19 @@ class SharedPtr {
97160
const T& operator*() const { return *m_storage; }
98161
~SharedPtr() { decrease_instance_count_(); }
99162
};
100-
101163
template <typename T>
102164
SharedPtr<T> WeakPtr<T>::lock() {
103165
return SharedPtr{ *this };
104166
}
105-
106167
template <typename T>
107168
class SharedPtr<T[]> {
108-
using RefCount = RefCountBase<T, true>;
169+
using RefCount = RefCountBase<true>;
109170
T* m_storage = nullptr;
110171
RefCount* m_count = nullptr;
111172
size_t m_size = 0ull;
112173
void decrease_instance_count_() {
113174
if (m_count == nullptr) return;
114-
m_count->decref();
175+
m_count->decref<T>();
115176
if (m_count->count() == 0) {
116177
delete m_count;
117178
m_count = nullptr;
@@ -154,7 +215,7 @@ class SharedPtr<T[]> {
154215
}
155216
bool operator==(const SharedPtr& other) const { return m_storage == other.m_storage; }
156217
T* release() {
157-
T* ptr = m_count->release_storage();
218+
T* ptr = m_count->release_storage<T>();
158219
decrease_instance_count_();
159220
m_count = nullptr;
160221
m_storage = nullptr;

include/UniquePtr.hpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,27 @@ class UniquePtr {
1414
UniquePtr(const UniquePtr&) = delete;
1515
explicit UniquePtr(T* ptr) : m_storage(ptr) {}
1616
explicit UniquePtr(T&& storage) : m_storage(new T{ move(storage) }) {}
17+
template <DerivedFrom<T> U>
18+
explicit UniquePtr(U* ptr) : m_storage(ptr) {}
19+
template <DerivedFrom<T> U>
20+
explicit UniquePtr(U&& storage) : m_storage(new U{ move(storage) }) {}
1721
UniquePtr(UniquePtr&& ptr) noexcept {
1822
reset();
1923
m_storage = ptr.release();
2024
}
25+
template <DerivedFrom<T> U>
26+
UniquePtr(UniquePtr<U>&& ptr) noexcept {
27+
reset();
28+
m_storage = ptr.release();
29+
}
2130
template <typename... Args>
2231
explicit UniquePtr(EmplaceT<T>, Args&&... args) {
2332
m_storage = new T{ Forward<Args>(args)... };
2433
}
34+
template <DerivedFrom<T> U, typename... Args>
35+
explicit UniquePtr(EmplaceT<U>, Args&&... args) {
36+
m_storage = new U{ Forward<Args>(args)... };
37+
}
2538
UniquePtr& operator=(UniquePtr&& other) noexcept {
2639
reset();
2740
m_storage = other.release();
@@ -47,7 +60,22 @@ class UniquePtr {
4760
T& operator*() & { return *m_storage; }
4861
const T& operator*() const& { return *m_storage; }
4962
T&& operator*() && = delete;
50-
63+
template <DerivedFrom<T> U>
64+
U* get() {
65+
return static_cast<U*>(m_storage);
66+
}
67+
template <DerivedFrom<T> U>
68+
const U* get() const {
69+
return static_cast<const U*>(m_storage);
70+
}
71+
template <DerivedFrom<T> U>
72+
U& as() {
73+
return *get<U>();
74+
}
75+
template <DerivedFrom<T> U>
76+
const U& get() const {
77+
return *get<U>();
78+
}
5179
// this is operator* which doesn't cause object slicing when moving outside of the UniquePtr
5280
template <DerivedFrom<T> U = T>
5381
U moved() && {

include/WeakPtr.hpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,39 @@
1212
#define SYNC_DEC(x) __sync_sub_and_fetch(x, 1)
1313
#endif
1414
namespace ARLib {
15-
template <typename T, bool Multiple = false>
15+
template <bool Multiple = false>
1616
class RefCountBase {
1717
unsigned long m_counter = 1;
1818
unsigned long m_weak_refs = 0;
19-
T* m_object = nullptr;
19+
void* m_object = nullptr;
20+
21+
template <typename T>
2022
void destroy() noexcept {
23+
T* mem = static_cast<T*>(m_object);
2124
if constexpr (Multiple) {
22-
delete[] m_object;
25+
delete[] mem;
2326
} else {
24-
delete m_object;
27+
delete mem;
2528
}
2629
}
2730

2831
public:
2932
constexpr RefCountBase() noexcept = default;
33+
template <typename T>
3034
explicit RefCountBase(T* object) : m_object(object) {}
35+
explicit RefCountBase(void* object) : m_object(object) {}
3136
RefCountBase(const RefCountBase&) = delete;
3237
RefCountBase& operator=(const RefCountBase&) = delete;
3338
void incref() noexcept { SYNC_INC(cast<volatile long*>(&m_counter)); }
3439
void incweakref() noexcept { SYNC_INC(cast<volatile long*>(&m_weak_refs)); }
40+
template <typename T>
3541
void decref() noexcept {
36-
if (SYNC_DEC(cast<volatile long*>(&m_counter)) == 0) { destroy(); }
42+
if (SYNC_DEC(cast<volatile long*>(&m_counter)) == 0) { destroy<T>(); }
3743
}
3844
void decweakref() noexcept { SYNC_DEC(cast<volatile long*>(&m_weak_refs)); }
45+
template <typename T>
3946
T* release_storage() {
40-
T* ptr = m_object;
47+
T* ptr = static_cast<T>(m_object);
4148
m_object = nullptr;
4249
return ptr;
4350
}
@@ -50,9 +57,9 @@ class SharedPtr;
5057
template <typename T>
5158
class WeakPtr {
5259
T* m_storage = nullptr;
53-
RefCountBase<T>* m_count = nullptr;
60+
RefCountBase<>* m_count = nullptr;
5461
friend SharedPtr<T>;
55-
WeakPtr(T* storage_ptr, RefCountBase<T>* count) : m_storage(storage_ptr), m_count(count) { m_count->incweakref(); }
62+
WeakPtr(T* storage_ptr, RefCountBase<>* count) : m_storage(storage_ptr), m_count(count) { m_count->incweakref(); }
5663
void decrease_instance_count_() {
5764
if (m_count == nullptr) return;
5865
m_count->decweakref();

0 commit comments

Comments
 (0)