-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathstim_utils.pybind.h
More file actions
73 lines (58 loc) · 2.31 KB
/
stim_utils.pybind.h
File metadata and controls
73 lines (58 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#ifndef _STIM_UTILS_PYBIND_H
#define _STIM_UTILS_PYBIND_H
// pybind11 headers (which pull in Python.h) stay first: Python.h must precede standard headers.
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <mutex>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>

#include "stim.h"
namespace tesseract_decoder {

// Short alias for pybind11 used by the helpers below. Declared directly in this namespace rather
// than inside an unnamed namespace: unnamed namespaces in a header give every including
// translation unit its own distinct namespace, which is discouraged in headers.
namespace py = pybind11;

/// Converts a C++ stim object into the equivalent Python `stim` object.
///
/// The object is round-tripped through its text form: `cpp_obj.str()` is passed to the callable
/// `stim.<py_name>` looked up on the imported Python `stim` module.
///
/// @param cpp_obj  Object exposing a `str()` method (e.g. a `stim::DetectorErrorModel`). Taken by
///                 const reference so large models are not copied on every call.
/// @param py_name  Attribute of the Python `stim` module to call, e.g. "DetectorErrorModel".
/// @return The Python object returned by `stim.<py_name>(cpp_obj.str())`.
template <typename T>
py::object make_py_object(const T& cpp_obj, const char* py_name) {
  auto stim_lib = py::module::import("stim");
  return stim_lib.attr(py_name)(cpp_obj.str());
}

/// Converts a Python object into a C++ object of type `T` via its text form.
///
/// Calls the Python object's `__str__()` and constructs `T` from the resulting `std::string`, so
/// `T` must be constructible from that string.
template <typename T>
T parse_py_object(py::object py_obj) {
  std::string obj_str = py::cast<std::string>(py_obj.attr("__str__")());
  return T(obj_str);
}

/// Maps the string held in a Python DEM instruction's `type` attribute to the matching
/// `stim::DemInstructionType`.
///
/// @throws std::invalid_argument if `dit_str` is not one of the recognised names.
inline stim::DemInstructionType parse_dit(const std::string& dit_str) {
  if (dit_str == "error") return stim::DemInstructionType::DEM_ERROR;
  if (dit_str == "detector") return stim::DemInstructionType::DEM_DETECTOR;
  if (dit_str == "logical_observable") return stim::DemInstructionType::DEM_LOGICAL_OBSERVABLE;
  if (dit_str == "shift_detectors") return stim::DemInstructionType::DEM_SHIFT_DETECTORS;
  if (dit_str == "repeat") return stim::DemInstructionType::DEM_REPEAT_BLOCK;
  throw std::invalid_argument("unknown dem instruction type: " + dit_str);
}

/// Converts a Python DEM target object (e.g. an element of `targets_copy()`) into a
/// `stim::DemTarget` by parsing its `__str__()` text with `stim::DemTarget::from_text`.
inline stim::DemTarget parse_py_dem_target(py::object py_obj) {
  return stim::DemTarget::from_text(py::cast<std::string>(py_obj.attr("__str__")()));
}

namespace detail {

/// Returns a view of `tag` whose backing storage lives until program exit.
///
/// `stim::DemInstruction::tag` is a non-owning view, so it must not point into a function-local
/// string. Each distinct non-empty tag is stored once in a process-wide set and kept for the rest
/// of the program, so memory grows with the number of distinct tags seen. References to
/// `std::unordered_set` elements are not invalidated by rehashing, so returned views stay valid.
/// The set is shared by every caller, hence the mutex.
inline std::string_view intern_dem_tag(std::string tag) {
  if (tag.empty()) return {};
  static std::mutex mutex;
  static std::unordered_set<std::string> interned;
  std::lock_guard<std::mutex> lock(mutex);
  return *interned.insert(std::move(tag)).first;
}

}  // namespace detail

/// Builds a `stim::DemInstruction` from a Python DEM instruction object.
///
/// `stim::DemInstruction` does not own its numeric arguments or targets: `arg_data` and
/// `target_data` are spans into caller-provided storage. This function appends the instruction's
/// arguments to `args` and its targets to `targets`. The returned instruction views exactly the
/// elements appended by this call, so reusing the same vectors across calls does not mix
/// instructions together.
///
/// Lifetime requirements for the caller:
///  - `args` and `targets` must outlive the returned instruction and must not grow (which may
///    reallocate and invalidate the spans) while the instruction is in use.
///  - The tag is interned by `detail::intern_dem_tag`, so it does not depend on any local string.
///
/// @param py_obj   Python object exposing `args_copy()`, `targets_copy()`, `type` and `tag`.
/// @param args     Storage the numeric arguments are appended to.
/// @param targets  Storage the parsed targets are appended to.
/// @throws std::invalid_argument if the `type` string is not recognised by `parse_dit`.
inline stim::DemInstruction parse_py_dem_instruction(py::object py_obj, std::vector<double>& args,
                                                     std::vector<stim::DemTarget>& targets) {
  // Remember where this call's data begins so that earlier contents are not included.
  const auto args_start = args.size();
  for (auto t : py_obj.attr("args_copy")()) args.push_back(t.cast<double>());

  const auto targets_start = targets.size();
  for (auto t : py_obj.attr("targets_copy")())
    targets.push_back(parse_py_dem_target(t.cast<py::object>()));

  auto ty = parse_dit(py::cast<std::string>(py_obj.attr("type")));

  auto di = stim::DemInstruction();
  // Spans are created only after both vectors are fully populated, so no push_back in this
  // function can invalidate them.
  di.arg_data = stim::SpanRef<const double>(args.data() + args_start, args.data() + args.size());
  di.target_data = stim::SpanRef<const stim::DemTarget>(targets.data() + targets_start,
                                                        targets.data() + targets.size());
  // Must not point at a local std::string: the view would dangle once this function returns.
  di.tag = detail::intern_dem_tag(py::cast<std::string>(py_obj.attr("tag")));
  di.type = ty;
  return di;
}

/// Returns `config.dem` converted to a Python `stim.DetectorErrorModel` object.
template <typename T>
py::object dem_getter(const T& config) {
  return make_py_object(config.dem, "DetectorErrorModel");
}

/// Replaces `config.dem` with the model parsed from the Python object `dem` (via its text form).
template <typename T>
void dem_setter(T& config, py::object dem) {
  config.dem = parse_py_object<stim::DetectorErrorModel>(dem);
}

}  // namespace tesseract_decoder
#endif