Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions glue/cirq/stimcirq/_obs_annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Tuple

import cirq
import stim
Expand All @@ -16,7 +16,7 @@ def __init__(
*,
parity_keys: Iterable[str] = (),
relative_keys: Iterable[int] = (),
pauli_keys: Iterable[str] = (),
pauli_keys: Iterable[tuple[cirq.Qid, str]] | Iterable[str] = (),
Comment thread
AlexBourassa marked this conversation as resolved.
Outdated
observable_index: int,
):
"""
Expand All @@ -29,15 +29,31 @@ def __init__(
"""
self.parity_keys = frozenset(parity_keys)
self.relative_keys = frozenset(relative_keys)
self.pauli_keys = frozenset(pauli_keys)
_pauli_keys = []
for k in pauli_keys:
if isinstance(k, str):
# For backward compatibility
_pauli_keys.append((cirq.LineQubit(int(k[1:])), k[0]))
else:
_pauli_keys.append(tuple(k))
self.pauli_keys = frozenset(_pauli_keys)
Comment thread
Strilanc marked this conversation as resolved.
self.observable_index = observable_index

@property
def qubits(self) -> Tuple[cirq.Qid, ...]:
return ()
return tuple(sorted(q for q, _ in self.pauli_keys))

def with_qubits(self, *new_qubits) -> 'CumulativeObservableAnnotation':
return self
if len(self.qubits) == len(new_qubits):
pauli_map = dict(self.pauli_keys)
return CumulativeObservableAnnotation(
parity_keys=self.parity_keys,
relative_keys=self.relative_keys,
pauli_keys=tuple((new_q, pauli_map[q]) for new_q, q in zip(new_qubits, self.qubits)),
observable_index=self.observable_index,
)

raise ValueError("Number of qubits does not match")

def _value_equality_values_(self) -> Any:
return self.parity_keys, self.relative_keys, self.pauli_keys, self.observable_index
Expand Down Expand Up @@ -85,6 +101,7 @@ def _stim_conversion_(
edit_measurement_key_lengths: List[Tuple[str, int]],
have_seen_loop: bool = False,
tag: str,
targets: list[int],
**kwargs,
):
# Ideally these references would all be resolved ahead of time, to avoid the redundant
Expand All @@ -109,10 +126,11 @@ def _stim_conversion_(
rec_targets.append(stim.target_rec(-1 - offset))
if not remaining:
break
pauli_map = dict(self.pauli_keys)
rec_targets.extend(
[
stim.target_pauli(qubit_index=int(k[1:]), pauli=k[0])
for k in sorted(self.pauli_keys)
stim.target_pauli(qubit_index=tid, pauli=pauli_map[q])
for q, tid in zip(self.qubits, targets)
]
)
if remaining:
Expand Down
23 changes: 15 additions & 8 deletions glue/cirq/stimcirq/_obs_annotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ def test_json_serialization():
assert c == c2

def test_json_serialization_with_pauli_keys():
pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")]
c = cirq.Circuit(
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=pauli_keys
),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=pauli_keys),
)
json = cirq.to_json(c)
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
Expand All @@ -208,13 +209,19 @@ def test_json_serialization_with_pauli_keys():
def test_json_backwards_compat_exact():
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5)
packed_v1 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}'
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v2 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
assert cirq.read_json(json_text=packed_v1, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3

# With pauli_keys
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=["X0", "Y1", "Z2"])
pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")]
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=pauli_keys)
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n "X0",\n "Y1",\n "Z2"\n ],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n [\n {\n "cirq_type": "LineQubit",\n "x": 0\n },\n "X"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 1\n },\n "Y"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 2\n },\n "Z"\n ]\n ],\n "relative_keys": [\n -2\n ]\n}'

assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3
4 changes: 2 additions & 2 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,13 @@ def coords_after_offset(

def resolve_measurement_record_keys(
self, targets: Iterable[stim.GateTarget]
) -> Tuple[List[str], List[int], List[str]]:
) -> Tuple[List[str], List[int], List[tuple[cirq.Qid, str]]]:
pauli_targets, meas_targets = [], []
for t in targets:
if t.is_measurement_record_target:
meas_targets.append(t)
else:
pauli_targets.append(f'{t.pauli_type}{t.value}')
pauli_targets.append((cirq.LineQubit(t.value), t.pauli_type))

if self.have_seen_loop:
return [], [t.value for t in meas_targets], pauli_targets
Expand Down
1 change: 1 addition & 0 deletions glue/cirq/stimcirq/_stim_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def test_round_trip_with_pauli_obs():
stim_circuit = stim.Circuit("""
QUBIT_COORDS(5, 5) 0
R 0
TICK
OBSERVABLE_INCLUDE(0) X0
TICK
H 0
Expand Down
Loading