Skip to content

Commit 1c7f998

Browse files
committed
[REFACTPR] Update the MapObj to be generic and allow subclass
1 parent 308756e commit 1c7f998

4 files changed

Lines changed: 227 additions & 271 deletions

File tree

include/tvm/ffi/container/container_details.h

Lines changed: 0 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -36,134 +36,6 @@
3636
namespace tvm {
3737
namespace ffi {
3838
namespace details {
39-
/*!
40-
* \brief Base template for classes with array like memory layout.
41-
*
42-
* It provides general methods to access the memory. The memory
43-
* layout is ArrayType + [ElemType]. The alignment of ArrayType
44-
* and ElemType is handled by the memory allocator.
45-
*
46-
* \tparam ArrayType The array header type, contains object specific metadata.
47-
* \tparam ElemType The type of objects stored in the array right after
48-
* ArrayType.
49-
*
50-
* \code{.cpp}
51-
* // Example usage of the template to define a simple array wrapper
52-
* class ArrayObj : public tvm::ffi::details::InplaceArrayBase<ArrayObj, Elem> {
53-
* public:
54-
* // Wrap EmplaceInit to initialize the elements
55-
* template <typename Iterator>
56-
* void Init(Iterator begin, Iterator end) {
57-
* size_t num_elems = std::distance(begin, end);
58-
* auto it = begin;
59-
* this->size = 0;
60-
* for (size_t i = 0; i < num_elems; ++i) {
61-
* InplaceArrayBase::EmplaceInit(i, *it++);
62-
* this->size++;
63-
* }
64-
* }
65-
* }
66-
*
67-
* void test_function() {
68-
* vector<Elem> fields;
69-
* auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size());
70-
* ptr->Init(fields.begin(), fields.end());
71-
*
72-
* // Access the 0th element in the array.
73-
* assert(ptr->operator[](0) == fields[0]);
74-
* }
75-
* \endcode
76-
*/
77-
template <typename ArrayType, typename ElemType>
78-
class InplaceArrayBase {
79-
public:
80-
/*!
81-
* \brief Access element at index
82-
* \param idx The index of the element.
83-
* \return Const reference to ElemType at the index.
84-
*/
85-
const ElemType& operator[](size_t idx) const {
86-
size_t size = Self()->GetSize();
87-
if (idx > size) {
88-
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size;
89-
}
90-
return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
91-
}
92-
93-
/*!
94-
* \brief Access element at index
95-
* \param idx The index of the element.
96-
* \return Reference to ElemType at the index.
97-
*/
98-
ElemType& operator[](size_t idx) {
99-
size_t size = Self()->GetSize();
100-
if (idx > size) {
101-
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size;
102-
}
103-
return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
104-
}
105-
106-
/*!
107-
* \brief Destroy the Inplace Array Base object
108-
*/
109-
~InplaceArrayBase() {
110-
if constexpr (!(std::is_standard_layout_v<ElemType> && std::is_trivial_v<ElemType>)) {
111-
size_t size = Self()->GetSize();
112-
for (size_t i = 0; i < size; ++i) {
113-
ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i));
114-
fp->ElemType::~ElemType();
115-
}
116-
}
117-
}
118-
119-
private:
120-
InplaceArrayBase() = default;
121-
friend ArrayType;
122-
123-
protected:
124-
/*!
125-
* \brief Construct a value in place with the arguments.
126-
*
127-
* \tparam Args Type parameters of the arguments.
128-
* \param idx Index of the element.
129-
* \param args Arguments to construct the new value.
130-
*
131-
* \note Please make sure ArrayType::GetSize returns 0 before first call of
132-
* EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds.
133-
*/
134-
template <typename... Args>
135-
void EmplaceInit(size_t idx, Args&&... args) {
136-
void* field_ptr = AddressOf(idx);
137-
new (field_ptr) ElemType(std::forward<Args>(args)...);
138-
}
139-
140-
/*!
141-
* \brief Return the self object for the array.
142-
*
143-
* \return Pointer to ArrayType.
144-
*/
145-
inline ArrayType* Self() const {
146-
return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this));
147-
}
148-
149-
/*!
150-
* \brief Return the raw pointer to the element at idx.
151-
*
152-
* \param idx The index of the element.
153-
* \return Raw pointer to the element.
154-
*/
155-
void* AddressOf(size_t idx) const {
156-
static_assert(
157-
alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0,
158-
"The size and alignment of ArrayType should respect "
159-
"ElemType's alignment.");
160-
161-
size_t kDataStart = sizeof(ArrayType);
162-
ArrayType* self = Self();
163-
char* data_start = reinterpret_cast<char*>(self) + kDataStart;
164-
return data_start + idx * sizeof(ElemType);
165-
}
166-
};
16739

16840
/*!
16941
* \brief iterator adapter that adapts TIter to return another type.

include/tvm/ffi/container/map.h

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,25 @@
3131
#include <tvm/ffi/object.h>
3232
#include <tvm/ffi/optional.h>
3333

34-
#include <type_traits>
3534
#include <unordered_map>
3635

3736
namespace tvm {
3837
namespace ffi {
3938

39+
/*! \brief Map object */
40+
class MapObj : public MapBaseObj {
41+
public:
42+
/// \cond Doxygen_Suppress
43+
static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap;
44+
static const constexpr bool _type_final = true;
45+
TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object);
46+
/// \endcond
47+
48+
protected:
49+
template <typename, typename, typename>
50+
friend class Map;
51+
};
52+
4053
/*!
4154
* \brief Map container of NodeRef->NodeRef in DSL graph.
4255
* Map implements copy on write semantics, which means map is mutable
@@ -64,7 +77,7 @@ class Map : public ObjectRef {
6477
/*!
6578
* \brief default constructor
6679
*/
67-
Map() { data_ = MapObj::Empty(); }
80+
Map() { data_ = MapObj::Empty<MapObj>(); }
6881
/*!
6982
* \brief move constructor
7083
* \param other source
@@ -159,22 +172,22 @@ class Map : public ObjectRef {
159172
*/
160173
template <typename IterType>
161174
Map(IterType begin, IterType end) {
162-
data_ = MapObj::CreateFromRange(begin, end);
175+
data_ = MapObj::CreateFromRange<MapObj>(begin, end);
163176
}
164177
/*!
165178
* \brief constructor from initializer list
166179
* \param init The initalizer list
167180
*/
168181
Map(std::initializer_list<std::pair<K, V>> init) {
169-
data_ = MapObj::CreateFromRange(init.begin(), init.end());
182+
data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
170183
}
171184
/*!
172185
* \brief constructor from unordered_map
173186
* \param init The unordered_map
174187
*/
175188
template <typename Hash, typename Equal>
176189
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
177-
data_ = MapObj::CreateFromRange(init.begin(), init.end());
190+
data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
178191
}
179192
/*!
180193
* \brief Read element from map.
@@ -206,7 +219,7 @@ class Map : public ObjectRef {
206219
void clear() {
207220
MapObj* n = GetMapObj();
208221
if (n != nullptr) {
209-
data_ = MapObj::Empty();
222+
data_ = MapObj::Empty<MapObj>();
210223
}
211224
}
212225
/*!
@@ -216,7 +229,7 @@ class Map : public ObjectRef {
216229
*/
217230
void Set(const K& key, const V& value) {
218231
CopyOnWrite();
219-
MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_);
232+
MapObj::InsertMaybeReHash<MapObj>(MapObj::KVType(key, value), &data_);
220233
}
221234
/*! \return begin iterator */
222235
iterator begin() const { return iterator(GetMapObj()->begin()); }
@@ -249,9 +262,9 @@ class Map : public ObjectRef {
249262
*/
250263
MapObj* CopyOnWrite() {
251264
if (data_.get() == nullptr) {
252-
data_ = MapObj::Empty();
265+
data_ = MapObj::Empty<MapObj>();
253266
} else if (!data_.unique()) {
254-
data_ = MapObj::CopyFrom(GetMapObj());
267+
data_ = MapObj::CopyFrom<MapObj>(GetMapObj());
255268
}
256269
return GetMapObj();
257270
}

0 commit comments

Comments
 (0)