-
Notifications
You must be signed in to change notification settings - Fork 897
Expand file tree
/
Copy pathpybind_utils.h
More file actions
120 lines (107 loc) · 3.89 KB
/
pybind_utils.h
File metadata and controls
120 lines (107 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#pragma once
#include <map/shot.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
// Namespace-macro compatibility shim: use pybind11's prefixed
// PYBIND11_NAMESPACE_BEGIN / PYBIND11_NAMESPACE_END when they are defined,
// otherwise fall back to the unprefixed NAMESPACE_BEGIN / NAMESPACE_END
// (presumably provided by older pybind11 releases -- TODO confirm).
// The trailing-underscore aliases are what the rest of this header uses.
#ifdef PYBIND11_NAMESPACE_BEGIN
#define PYBIND11_NAMESPACE_BEGIN_ PYBIND11_NAMESPACE_BEGIN
#define PYBIND11_NAMESPACE_END_ PYBIND11_NAMESPACE_END
#else
#define PYBIND11_NAMESPACE_BEGIN_ NAMESPACE_BEGIN
#define PYBIND11_NAMESPACE_END_ NAMESPACE_END
#endif
PYBIND11_NAMESPACE_BEGIN_(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN_(detail)
// See https://github.com/pybind/pybind11/issues/637
// Also fbsource/fbcode/caffe2/torch/csrc/jit/python/pybind.h
// Caster for std::vector<map::Landmark *> built on pybind11's generic list
// caster. Whatever return-value policy the caller requested is ignored: the
// elements are always converted with return_value_policy::reference.
using ListCasterBase =
    pybind11::detail::list_caster<std::vector<map::Landmark *>,
                                  map::Landmark *>;
template <>
struct type_caster<std::vector<map::Landmark *>> : ListCasterBase {
  // Converts a vector (by reference) to a Python list. The incoming policy
  // argument is intentionally unnamed and replaced by `reference`.
  static handle cast(const std::vector<map::Landmark *> &src,
                     return_value_policy, handle parent) {
    return ListCasterBase::cast(src, return_value_policy::reference, parent);
  }

  // Converts a vector passed by pointer. A null pointer is mapped to Python
  // `None` (as pybind11's stock casters do) instead of being dereferenced.
  static handle cast(const std::vector<map::Landmark *> *src,
                     return_value_policy pol, handle parent) {
    if (src == nullptr) {
      return pybind11::none().release();
    }
    return cast(*src, pol, parent);
  }
};
// Tag values used as a non-type template argument of sfm_iterator_state so
// that iterator states built by different factories become distinct C++
// types. The numeric values are the ones the compiler would assign
// implicitly; they are spelled out only for readability.
enum IteratorType {
  // KeyIterator,
  ValueIterator = 0,
  ItemIterator = 1,
  UniquePtrValueIterator = 2,
  UniquePtrIterator = 3,
  RefIterator = 4,
  RefValueIterator = 5
};
// State object exposed to Python as the iterator instance.
// `it_type` and `Policy` are not used by any member; they only take part in
// the type's identity, so each (Iterator, Sentinel, tag, policy) combination
// is a separate C++ type.
// The struct is brace-initialised by the factories in this header as
// {first, last, true}, so the member order below must not change.
template <typename Iterator, typename Sentinel, IteratorType it_type,
return_value_policy Policy>
struct sfm_iterator_state {
// Current position in the range.
Iterator it;
// One-past-the-end marker that `it` is compared against.
Sentinel end;
// True before the first `__next__` call (so the first call does not advance)
// and set to true again once the end has been reached.
bool first_or_done;
};
PYBIND11_NAMESPACE_END_(detail)
// Creates a Python iterator over [first, last) whose `__next__` returns the
// address of the `second` member of the current element (elements are
// expected to be pair-like, e.g. entries of a map).
//
// Template parameters:
//   Policy  - return value policy forwarded to the `__next__` binding.
//   KeyType - return type of `__next__`; defaults to a pointer to `second`.
//   Extra   - additional pybind11 attributes forwarded to `__next__`.
//
// The helper class "ref_value_iterator" is registered (module-local) the
// first time a given state type is seen. The returned object holds copies of
// `first` and `last` only.
template <return_value_policy Policy = return_value_policy::reference_internal,
          typename Iterator, typename Sentinel,
          typename KeyType = decltype(&((*std::declval<Iterator>()).second)),
          typename... Extra>
iterator make_ref_value_iterator(Iterator first, Sentinel last,
                                 Extra &&... extra) {
  using state = detail::sfm_iterator_state<Iterator, Sentinel,
                                           detail::RefValueIterator, Policy>;

  if (detail::get_type_info(typeid(state), false) == nullptr) {
    // `__iter__`: an iterator is its own iterable.
    auto self_ref = [](state &self) -> state & { return self; };

    // `__next__`: skip the increment on the very first call, then either
    // signal exhaustion or hand out a pointer to the mapped value.
    auto step = [](state &self) -> KeyType {
      if (self.first_or_done) {
        self.first_or_done = false;
      } else {
        ++self.it;
      }
      if (self.it == self.end) {
        self.first_or_done = true;
        throw stop_iteration();
      }
      return &(self.it->second);
    };

    class_<state>(handle(), "ref_value_iterator", pybind11::module_local())
        .def("__iter__", self_ref)
        .def("__next__", step, std::forward<Extra>(extra)..., Policy);
  }

  return cast(state{first, last, true});
}
// Creates a Python iterator over [first, last) whose `__next__` returns a
// tuple `(element.first, &element.second)` for the current element (elements
// are expected to be pair-like, e.g. entries of a map).
//
// Template parameters:
//   Policy  - return value policy forwarded to the `__next__` binding. The
//             tuple itself is built with pybind11::make_tuple's default
//             policy.
//   KeyType - return type of `__next__`; defaults to pybind11::tuple.
//   Extra   - additional pybind11 attributes forwarded to `__next__`.
//
// The helper class "ref_iterator" is registered (module-local) the first time
// a given state type is seen. The state is tagged with detail::RefIterator so
// its type identity cannot coincide with a state created by a different
// iterator factory for the same Iterator/Sentinel/Policy; registration is
// keyed on typeid(state), so a shared tag would make the second factory
// silently reuse the first one's `__next__`.
// The returned object holds copies of `first` and `last` only.
template <
    return_value_policy Policy = return_value_policy::reference_internal,
    typename Iterator, typename Sentinel,
    typename KeyType = pybind11::tuple,
    typename... Extra>
iterator make_ref_iterator(Iterator first, Sentinel last, Extra &&... extra) {
  using state = detail::sfm_iterator_state<Iterator, Sentinel,
                                           detail::RefIterator, Policy>;
  if (!detail::get_type_info(typeid(state), false)) {
    class_<state>(handle(), "ref_iterator", pybind11::module_local())
        .def("__iter__", [](state &s) -> state & { return s; })
        .def("__next__",
             [](state &s) -> KeyType {
               // The first call must not advance: `it` already points at
               // the first element.
               if (!s.first_or_done) {
                 ++s.it;
               } else {
                 s.first_or_done = false;
               }
               if (s.it == s.end) {
                 // Re-arm the flag so further calls keep raising
                 // StopIteration without incrementing past the end.
                 s.first_or_done = true;
                 throw stop_iteration();
               }
               return pybind11::make_tuple(s.it->first, &(s.it->second));
             },
             std::forward<Extra>(extra)..., Policy);
  }
  return cast(state{first, last, true});
}
PYBIND11_NAMESPACE_END_(PYBIND11_NAMESPACE)