-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathtesseract.pybind.h
More file actions
492 lines (443 loc) · 22.2 KB
/
tesseract.pybind.h
File metadata and controls
492 lines (443 loc) · 22.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _TESSERACT_PYBIND_H
#define _TESSERACT_PYBIND_H
#include <pybind11/iostream.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// Standard headers are listed after pybind11 on purpose: pybind11 includes
// Python.h, which must be included before any standard header.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "stim_utils.pybind.h"
#include "tesseract.h"
#include "utils.h"
namespace py = pybind11;
namespace tesseract_decoder {
namespace {

// Compiles `self` into a new heap-allocated `TesseractDecoder`. Bound as
// `TesseractConfig.compile_decoder`.
std::unique_ptr<TesseractDecoder> _compile_tesseract_decoder_helper(const TesseractConfig& self) {
  return std::make_unique<TesseractDecoder>(self);
}

// Shared implementation of the two `TesseractConfig` factories below.
//
// When `det_orders` is empty, default detector orderings are generated from
// `dem`: 20 orderings using `DetOrder::DetBFS`. The trailing literal passed to
// `build_det_orders` appears to be a fixed seed that keeps the defaults
// reproducible -- NOTE(review): confirm against build_det_orders' signature.
//
// `dem` and `det_orders` are taken by value and moved into the resulting
// config, so a large detector error model is not copied an extra time.
TesseractConfig build_tesseract_config_impl(stim::DetectorErrorModel dem, int det_beam,
                                            bool beam_climbing, bool no_revisit_dets,
                                            bool verbose, bool merge_errors, size_t pqlimit,
                                            std::vector<std::vector<size_t>> det_orders,
                                            double det_penalty, bool create_visualization) {
  if (det_orders.empty()) {
    det_orders = build_det_orders(dem, 20, DetOrder::DetBFS, 2384753);
  }
  // `dem` is only moved from here, after its last read above.
  return TesseractConfig({std::move(dem), det_beam, beam_climbing, no_revisit_dets, verbose,
                          merge_errors, pqlimit, std::move(det_orders), det_penalty,
                          create_visualization});
}

// Factory behind `TesseractConfig(...)` when no `dem` is given. The config holds
// an empty `stim::DetectorErrorModel`, and any auto-generated `det_orders` are
// derived from that empty model.
TesseractConfig tesseract_config_maker_no_dem(
    int det_beam = INF_DET_BEAM, bool beam_climbing = false, bool no_revisit_dets = false,
    bool verbose = false, bool merge_errors = true,
    size_t pqlimit = std::numeric_limits<size_t>::max(),
    std::vector<std::vector<size_t>> det_orders = std::vector<std::vector<size_t>>(),
    double det_penalty = 0.0, bool create_visualization = false) {
  return build_tesseract_config_impl(stim::DetectorErrorModel(), det_beam, beam_climbing,
                                     no_revisit_dets, verbose, merge_errors, pqlimit,
                                     std::move(det_orders), det_penalty, create_visualization);
}

// Factory behind `TesseractConfig(dem, ...)`. `dem` is any Python object accepted
// by `parse_py_object<stim::DetectorErrorModel>`. When `det_orders` is empty,
// default orderings are generated from the parsed model.
TesseractConfig tesseract_config_maker(
    py::object dem, int det_beam = INF_DET_BEAM, bool beam_climbing = false,
    bool no_revisit_dets = false, bool verbose = false, bool merge_errors = true,
    size_t pqlimit = std::numeric_limits<size_t>::max(),
    std::vector<std::vector<size_t>> det_orders = std::vector<std::vector<size_t>>(),
    double det_penalty = 0.0, bool create_visualization = false) {
  return build_tesseract_config_impl(parse_py_object<stim::DetectorErrorModel>(dem), det_beam,
                                     beam_climbing, no_revisit_dets, verbose, merge_errors,
                                     pqlimit, std::move(det_orders), det_penalty,
                                     create_visualization);
}

}  // namespace
// Registers the `tesseract` submodule on `root`, exposing `INF_DET_BEAM`,
// `TesseractConfig` and `TesseractDecoder` to Python.
//
// Exceptions thrown by the bound lambdas are translated by pybind11:
// std::invalid_argument surfaces as ValueError and std::runtime_error as
// RuntimeError.
void add_tesseract_module(py::module& root) {
  auto m = root.def_submodule("tesseract", "Module containing the tesseract algorithm");

  // A sentinel value indicating an infinite beam size for the decoder. A Python
  // int cannot carry its own docstring. Assigning the description to `m.doc()`
  // would overwrite the submodule docstring set above, so it is documented here
  // and in the `det_beam` parameter docs instead.
  m.attr("INF_DET_BEAM") = INF_DET_BEAM;

  // Checks that a dense 1D boolean syndrome has one entry per detector of
  // `decoder`. It then returns the indices of the entries that are set, which is
  // the sparse form the decoder consumes. Throws std::invalid_argument on a
  // length mismatch. `unchecked<1>()` additionally rejects arrays that are not
  // one-dimensional.
  auto syndrome_to_detections = [](const TesseractDecoder& decoder,
                                   const py::array_t<bool>& syndrome) {
    if (static_cast<size_t>(syndrome.size()) != decoder.num_detectors) {
      std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
                        ") does not match the number of detectors in the decoder (" +
                        std::to_string(decoder.num_detectors) + ").";
      throw std::invalid_argument(msg);
    }
    std::vector<uint64_t> detections;
    auto syndrome_unchecked = syndrome.unchecked<1>();
    for (size_t i = 0; i < static_cast<size_t>(syndrome_unchecked.size()); ++i) {
      if (syndrome_unchecked(i)) {
        detections.push_back(i);
      }
    }
    return detections;
  };

  // Packs the observables flipped by the decoder's most recent prediction
  // (`predicted_errors_buffer`) into a 1D NumPy bool array of length
  // `num_observables`.
  auto flipped_observables_array = [](TesseractDecoder& decoder) {
    // `std::vector<bool>` is a bit-packed specialization with no contiguous
    // `bool` buffer, so a `std::vector<char>` of 0/1 bytes backs the NumPy
    // array. Without a `base` handle, py::array copies the data.
    std::vector<char> flipped(decoder.num_observables, 0);
    for (auto obs_index : decoder.get_flipped_observables(decoder.predicted_errors_buffer)) {
      flipped[obs_index] ^= 1;
    }
    return py::array(py::dtype::of<bool>(), flipped.size(), flipped.data());
  };

  auto py_tesseract_config = py::class_<TesseractConfig>(m, "TesseractConfig", R"pbdoc(
Configuration object for the `TesseractDecoder`.
This class holds all the parameters needed to initialize and configure a
Tesseract decoder instance.
)pbdoc");
  auto py_tesseract_decoder = py::class_<TesseractDecoder>(m, "TesseractDecoder", R"pbdoc(
A class that implements the Tesseract decoding algorithm.
It can decode syndromes from a `stim.DetectorErrorModel` to predict
which observables have been flipped.
)pbdoc");

  // NOTE: the Python-facing defaults below (det_beam=5, no_revisit_dets=True,
  // pqlimit=200000) intentionally differ from the C++ factory defaults. The
  // docstrings describe the Python-facing values.
  py_tesseract_config
      .def(py::init<>(), R"pbdoc(
Default constructor for TesseractConfig.
Creates a new instance with default parameter values.
)pbdoc")
      .def(py::init(&tesseract_config_maker_no_dem), py::arg("det_beam") = 5,
           py::arg("beam_climbing") = false, py::arg("no_revisit_dets") = true,
           py::arg("verbose") = false, py::arg("merge_errors") = true, py::arg("pqlimit") = 200000,
           py::arg("det_orders") = std::vector<std::vector<size_t>>(), py::arg("det_penalty") = 0.0,
           py::arg("create_visualization") = false,
           R"pbdoc(
The constructor for the `TesseractConfig` class without a `dem` argument.
This creates an empty `DetectorErrorModel` by default.
Parameters
----------
det_beam : int, default=5
Beam cutoff that specifies the maximum number of detection events a search state can have.
Pass `INF_DET_BEAM` for an infinite beam.
beam_climbing : bool, default=False
If True, enables a beam climbing heuristic.
no_revisit_dets : bool, default=True
If True, prevents the decoder from revisiting a syndrome pattern more than once.
verbose : bool, default=False
If True, enables verbose logging from the decoder.
merge_errors : bool, default=True
If True, merges error channels that have identical syndrome patterns.
pqlimit : int, default=200000
The maximum size of the priority queue.
det_orders : list[list[int]], default=empty
A list of detector orderings to use for decoding. If empty, default orderings
are generated from the empty placeholder `DetectorErrorModel`.
det_penalty : float, default=0.0
A penalty value added to the cost of each detector visited.
create_visualization : bool, default=False
Whether to record the information needed to create a visualization or not.
)pbdoc")
      .def(py::init(&tesseract_config_maker), py::arg("dem"), py::arg("det_beam") = 5,
           py::arg("beam_climbing") = false, py::arg("no_revisit_dets") = true,
           py::arg("verbose") = false, py::arg("merge_errors") = true, py::arg("pqlimit") = 200000,
           py::arg("det_orders") = std::vector<std::vector<size_t>>(), py::arg("det_penalty") = 0.0,
           py::arg("create_visualization") = false,
           R"pbdoc(
The constructor for the `TesseractConfig` class.
Parameters
----------
dem : stim.DetectorErrorModel
The detector error model to be decoded.
det_beam : int, default=5
Beam cutoff that specifies the maximum number of detection events a search state can have.
Pass `INF_DET_BEAM` for an infinite beam.
beam_climbing : bool, default=False
If True, enables a beam climbing heuristic.
no_revisit_dets : bool, default=True
If True, prevents the decoder from revisiting a syndrome pattern more than once.
verbose : bool, default=False
If True, enables verbose logging from the decoder.
merge_errors : bool, default=True
If True, merges error channels that have identical syndrome patterns.
pqlimit : int, default=200000
The maximum size of the priority queue.
det_orders : list[list[int]], default=empty
A list of detector orderings to use for decoding. If empty, 20 orderings are
generated from `dem` using a BFS-based detector ordering.
det_penalty : float, default=0.0
A penalty value added to the cost of each detector visited.
create_visualization : bool, default=False
Whether to record the information needed to create a visualization or not.
)pbdoc")
      .def_property("dem", &dem_getter<TesseractConfig>, &dem_setter<TesseractConfig>,
                    "The `stim.DetectorErrorModel` that defines the error channels and detectors.")
      .def_readwrite("det_beam", &TesseractConfig::det_beam,
                     "Beam cutoff argument for the beam search.")
      .def_readwrite("beam_climbing", &TesseractConfig::beam_climbing,
                     "Whether to use a beam climbing heuristic.")
      .def_readwrite("no_revisit_dets", &TesseractConfig::no_revisit_dets,
                     "Whether to prevent revisiting same syndrome patterns during decoding.")
      .def_readwrite("verbose", &TesseractConfig::verbose,
                     "If True, the decoder will print verbose output.")
      .def_readwrite("merge_errors", &TesseractConfig::merge_errors,
                     "If True, merges error channels that have identical syndrome patterns.")
      .def_readwrite("pqlimit", &TesseractConfig::pqlimit,
                     "The maximum size of the priority queue.")
      .def_readwrite("det_orders", &TesseractConfig::det_orders,
                     "A list of pre-specified detector orderings.")
      .def_readwrite("det_penalty", &TesseractConfig::det_penalty,
                     "The penalty cost added for each detector.")
      .def_readwrite("create_visualization", &TesseractConfig::create_visualization,
                     "If True, records necessary information to create visualization.")
      .def("__str__", &TesseractConfig::str)
      .def("compile_decoder", &_compile_tesseract_decoder_helper,
           py::return_value_policy::take_ownership,
           R"pbdoc(
Compiles the configuration into a new `TesseractDecoder` instance.
Returns
-------
TesseractDecoder
A new `TesseractDecoder` instance configured with the current
settings.
)pbdoc")
      .def(
          "compile_decoder_for_dem",
          [](TesseractConfig& self, py::object dem) {
            // Deliberately stores the parsed model on `self` (existing behavior).
            self.dem = parse_py_object<stim::DetectorErrorModel>(dem);
            return std::make_unique<TesseractDecoder>(self);
          },
          py::arg("dem"), py::return_value_policy::take_ownership, R"pbdoc(
Compiles the configuration into a new `TesseractDecoder` instance
for a given `dem` object.
Note that this replaces the `dem` stored in this `TesseractConfig`
with the provided one.
Parameters
----------
dem : stim.DetectorErrorModel
The detector error model to use for the decoder.
Returns
-------
TesseractDecoder
A new `TesseractDecoder` instance configured with the
provided `dem` and the other settings from this
`TesseractConfig` object.
)pbdoc");

  py_tesseract_decoder
      .def(py::init<TesseractConfig>(), py::arg("config"), R"pbdoc(
The constructor for the `TesseractDecoder` class.
Parameters
----------
config : TesseractConfig
The configuration object for the decoder.
)pbdoc")
      .def(
          "decode_to_errors",
          [syndrome_to_detections](TesseractDecoder& self, const py::array_t<bool>& syndrome) {
            std::vector<uint64_t> detections = syndrome_to_detections(self, syndrome);
            self.decode_to_errors(detections);
            return self.predicted_errors_buffer;
          },
          py::arg("syndrome"),
          py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
          R"pbdoc(
Decodes a single shot to a list of error indices.
Parameters
----------
syndrome : np.ndarray
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
The length of the array should match the number of detectors in the DEM.
Returns
-------
list[int]
A list of predicted error indices from the original flattened DEM.
)pbdoc")
      .def(
          "decode_to_errors",
          [syndrome_to_detections](TesseractDecoder& self, const py::array_t<bool>& syndrome,
                                   size_t det_order, size_t det_beam) {
            std::vector<uint64_t> detections = syndrome_to_detections(self, syndrome);
            self.decode_to_errors(detections, det_order, det_beam);
            return self.predicted_errors_buffer;
          },
          py::arg("syndrome"), py::arg("det_order"), py::arg("det_beam"),
          py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
          R"pbdoc(
Decodes a single shot using a specific detector ordering and beam size.
Parameters
----------
syndrome : np.ndarray
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
The length of the array should match the number of detectors in the DEM.
det_order : int
The index of the detector ordering to use.
det_beam : int
The beam size to use during the decoding.
Returns
-------
list[int]
A list of predicted error indices from the original flattened DEM.
)pbdoc")
      .def(
          "get_observables_from_errors",
          [](TesseractDecoder& self, const std::vector<size_t>& predicted_errors) {
            std::vector<bool> result(self.num_observables, false);
            for (auto obs_index : self.get_flipped_observables(predicted_errors)) {
              result[obs_index] = !result[obs_index];
            }
            return result;
          },
          py::arg("predicted_errors"), R"pbdoc(
Converts a list of predicted error indices into a list of
flipped logical observables.
Parameters
----------
predicted_errors : list[int]
A list of integers representing error indices from the original flattened DEM.
Returns
-------
list[bool]
A list of booleans, where each boolean corresponds to a
logical observable and is `True` if the observable was flipped.
)pbdoc")
      .def("cost_from_errors", &TesseractDecoder::cost_from_errors, py::arg("predicted_errors"),
           R"pbdoc(
Calculates the sum of the likelihood costs of the predicted errors.
The likelihood cost of an error with probability p is log((1 - p) / p).
Parameters
----------
predicted_errors : list[int]
A list of integers representing error indices from the original flattened DEM.
Returns
-------
float
A float representing the sum of the likelihood costs of the
predicted errors.
)pbdoc")
      .def(
          "decode_from_detection_events",
          [flipped_observables_array](TesseractDecoder& self,
                                      const std::vector<uint64_t>& detections) {
            // Detector indices arrive unchecked from Python. Reject out-of-range
            // values before they reach the decoder.
            for (uint64_t detector : detections) {
              if (detector >= static_cast<uint64_t>(self.num_detectors)) {
                throw std::invalid_argument(
                    "Detection index (" + std::to_string(detector) +
                    ") is out of range for a decoder with " +
                    std::to_string(self.num_detectors) + " detectors.");
              }
            }
            self.decode(detections);
            return flipped_observables_array(self);
          },
          py::arg("detections"),
          py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
          R"pbdoc(
Decodes a single shot from a list of detection events.
Parameters
----------
detections : list[int]
A list of indices corresponding to the detectors that were
fired. This input represents a single measurement shot.
Returns
-------
np.ndarray
A 1D NumPy array of booleans. Each boolean value indicates whether the
decoder predicts that the corresponding logical observable has been flipped.
Raises
------
ValueError
If any detection index is not smaller than the number of detectors in the decoder.
)pbdoc")
      .def(
          "decode",
          [syndrome_to_detections, flipped_observables_array](
              TesseractDecoder& self, const py::array_t<bool>& syndrome) {
            std::vector<uint64_t> detections = syndrome_to_detections(self, syndrome);
            self.decode(detections);
            return flipped_observables_array(self);
          },
          py::arg("syndrome"),
          py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
          R"pbdoc(
Decodes a single shot.
Parameters
----------
syndrome : np.ndarray
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
The length of the array should match the number of detectors in the DEM.
Returns
-------
np.ndarray
A 1D NumPy array of booleans indicating which observables are flipped.
The length of the array matches the number of observables.
)pbdoc")
      .def(
          "decode_batch",
          [](TesseractDecoder& self, const py::array_t<bool>& syndromes) {
            if (syndromes.ndim() != 2) {
              throw std::runtime_error("Input syndromes must be a 2D NumPy array.");
            }
            // Rows are shots; columns are detectors.
            auto syndromes_unchecked = syndromes.unchecked<2>();
            size_t num_shots = syndromes_unchecked.shape(0);
            size_t num_detectors = syndromes_unchecked.shape(1);
            if (num_detectors != self.num_detectors) {
              std::string msg = "The number of detectors in the input array (" +
                                std::to_string(num_detectors) +
                                ") does not match the number of detectors in the decoder (" +
                                std::to_string(self.num_detectors) + ").";
              throw std::invalid_argument(msg);
            }
            py::array_t<bool> result({num_shots, self.num_observables});
            result.attr("fill")(0);
            auto result_unchecked = result.mutable_unchecked<2>();
            // Reused across shots; clear() keeps its capacity.
            std::vector<uint64_t> detections;
            for (size_t i = 0; i < num_shots; ++i) {
              detections.clear();
              for (size_t j = 0; j < num_detectors; ++j) {
                if (syndromes_unchecked(i, j)) {
                  detections.push_back(j);
                }
              }
              self.decode(detections);
              for (auto obs_index : self.get_flipped_observables(self.predicted_errors_buffer)) {
                result_unchecked(i, obs_index) ^= 1;
              }
            }
            return result;
          },
          py::arg("syndromes"),
          // Same stdout/stderr redirection as the single-shot decode methods, so
          // verbose decoder output reaches Python's sys.stdout/sys.stderr.
          py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
          R"pbdoc(
Decodes a batch of shots.
Parameters
----------
syndromes : np.ndarray
A 2D NumPy array of booleans with shape (num_shots, num_detectors), where
each row holds the detector outcomes for one shot.
Returns
-------
np.ndarray
A 2D NumPy array of booleans with shape (num_shots, num_observables), where
each row indicates which logical observables the decoder predicts were
flipped in that shot.
)pbdoc")
      .def_readwrite("config", &TesseractDecoder::config,
                     "The configuration used to create this decoder.")
      .def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag,
                     "A flag indicating if the decoder's prediction has low confidence.")
      .def_readwrite(
          "predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer,
          "A buffer containing the predicted errors from the most recent decode operation.")
      .def_readwrite("errors", &TesseractDecoder::errors,
                     "The list of all errors in the detector error model.")
      .def_readwrite("num_observables", &TesseractDecoder::num_observables,
                     "The total number of logical observables in the detector error model.")
      .def_readwrite("num_detectors", &TesseractDecoder::num_detectors,
                     "The total number of detectors in the detector error model.")
      .def_readonly("visualizer", &TesseractDecoder::visualizer,
                    "An object that can (if config.create_visualization=True) be used to generate "
                    "visualization of the algorithm");
}
} // namespace tesseract_decoder
#endif