forked from quantumlib/tesseract-decoder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimplex.pybind.h
More file actions
110 lines (100 loc) · 4.4 KB
/
simplex.pybind.h
File metadata and controls
110 lines (100 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _SIMPLEX_PYBIND_H
#define _SIMPLEX_PYBIND_H
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iostream>
#include "common.h"
#include "simplex.h"
#include "stim_utils.pybind.h"
namespace py = pybind11;
namespace {
// Python-facing factory for SimplexConfig; bound below as the class's __init__.
//
// dem: any Python object accepted by parse_py_object<stim::DetectorErrorModel>
//   (NOTE(review): presumably a stim.DetectorErrorModel — confirm in stim_utils.pybind.h).
// parallelize, window_length, window_slide_length: stored on the config unchanged.
// verbose_callback: None (logging disabled) or a Python callable that receives each
//   log string. Every invocation acquires the GIL before calling into Python.
// Returns the fully populated SimplexConfig by value.
SimplexConfig simplex_config_maker(py::object dem, bool parallelize = false,
                                   size_t window_length = 0, size_t window_slide_length = 0,
                                   py::object verbose_callback = py::none()) {
  SimplexConfig cfg;
  // Assign the parsed model straight from the returned temporary instead of going
  // through a named local, so a potentially large DEM is not copied an extra time.
  cfg.dem = parse_py_object<stim::DetectorErrorModel>(dem);
  cfg.parallelize = parallelize;
  cfg.window_length = window_length;
  cfg.window_slide_length = window_slide_length;

  const bool active = !verbose_callback.is_none();
  // Stays empty when no Python callback was supplied.
  std::function<void(const std::string&)> cb;
  if (active) {
    py::function f = verbose_callback;
    cb = [f](const std::string& s) {
      // Acquire the GIL before calling into Python; gil_scoped_acquire is also
      // safe when the calling thread already holds it.
      py::gil_scoped_acquire gil;
      f(s);
    };
  }
  cfg.verbose_callback = cb;
  // NOTE(review): CallbackStream is declared in common.h; whether it copies or merely
  // references cfg.verbose_callback is not visible here — confirm, since cfg is
  // returned (copied/moved) by value.
  cfg.log_stream = CallbackStream(active, cfg.verbose_callback);
  return cfg;
}
}  // namespace
// Registers the `simplex` submodule on `root`, exposing SimplexConfig and
// SimplexDecoder to Python.
// NOTE(review): this function is defined in a header without `inline`; including the
// header from more than one translation unit would produce duplicate definitions at
// link time.
void add_simplex_module(py::module& root) {
  auto m =
      root.def_submodule("simplex", "Module containing the SimplexDecoder and related methods");

  py::class_<SimplexConfig>(m, "SimplexConfig")
      .def(py::init(&simplex_config_maker), py::arg("dem"), py::arg("parallelize") = false,
           py::arg("window_length") = 0, py::arg("window_slide_length") = 0,
           py::arg("verbose_callback") = py::none())
      .def_property("dem", &dem_getter<SimplexConfig>, &dem_setter<SimplexConfig>)
      .def_readwrite("parallelize", &SimplexConfig::parallelize)
      .def_readwrite("window_length", &SimplexConfig::window_length)
      .def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)
      .def("windowing_enabled", &SimplexConfig::windowing_enabled)
      .def("__str__", &SimplexConfig::str);

  py::class_<SimplexDecoder>(m, "SimplexDecoder")
      .def(py::init<SimplexConfig>(), py::arg("config"))
      .def_readwrite("config", &SimplexDecoder::config)
      .def_readwrite("errors", &SimplexDecoder::errors)
      .def_readwrite("num_detectors", &SimplexDecoder::num_detectors)
      .def_readwrite("num_observables", &SimplexDecoder::num_observables)
      .def_readwrite("predicted_errors_buffer", &SimplexDecoder::predicted_errors_buffer)
      .def_readwrite("error_masks", &SimplexDecoder::error_masks)
      .def_readwrite("start_time_to_errors", &SimplexDecoder::start_time_to_errors)
      .def_readwrite("end_time_to_errors", &SimplexDecoder::end_time_to_errors)
      .def_readonly("low_confidence_flag", &SimplexDecoder::low_confidence_flag)
      .def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections"))
      // Returns a list of num_observables bools; entry i is True when
      // get_flipped_observables reports observable i as flipped by predicted_errors.
      .def(
          "get_observables_from_errors",
          [](SimplexDecoder& self, const std::vector<size_t>& predicted_errors) {
            // predicted_errors comes straight from Python and indexes self.errors (the
            // same convention decode() uses below). Reject out-of-range indices with an
            // IndexError instead of passing them on unchecked.
            for (size_t ei : predicted_errors) {
              if (ei >= self.errors.size()) {
                throw py::index_error("predicted error index " + std::to_string(ei) +
                                      " is out of range (number of errors: " +
                                      std::to_string(self.errors.size()) + ")");
              }
            }
            std::vector<bool> result(self.num_observables, false);
            for (const auto index : self.get_flipped_observables(predicted_errors)) {
              // .at() turns a bad observable index into std::out_of_range, which
              // pybind11 surfaces as IndexError, rather than an out-of-bounds write.
              result.at(index) = true;
            }
            return result;
          },
          py::arg("predicted_errors"))
      .def("cost_from_errors", &SimplexDecoder::cost_from_errors, py::arg("predicted_errors"))
      // Runs the decoder, then returns a list of num_observables bools holding the
      // parity of how many predicted errors (predicted_errors_buffer) flip each observable.
      .def(
          "decode",
          [](SimplexDecoder& self, const std::vector<uint64_t>& detections) {
            std::vector<bool> result(self.num_observables, false);
            self.decode(detections);
            for (size_t ei : self.predicted_errors_buffer) {
              for (int obs_index : self.errors[ei].symptom.observables) {
                result[obs_index].flip();
              }
            }
            return result;
          },
          py::arg("detections"));
}
#endif