From 2cb8d0591933ee1eac04a2288e9a9327f6f5bda5 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Thu, 12 Mar 2026 16:51:11 -0400 Subject: [PATCH 1/3] refactor(circuit): Only flatten circuit when a ZX graph is constructed --- src/tsim/circuit.py | 28 +++++------ src/tsim/utils/diagram.py | 19 +++++-- src/tsim/utils/encoder.py | 3 +- test/unit/core/test_parse.py | 31 ++++++++++++ test/unit/test_circuit.py | 88 ++++++++++++++++++++++++++++----- test/unit/utils/test_diagram.py | 13 +++++ 6 files changed, 151 insertions(+), 31 deletions(-) diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index 611b6e2..d4d60a3 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -38,7 +38,7 @@ def __init__(self, stim_program_text: str = ""): empty circuit. """ - self._stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)).flattened() + self._stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)) @classmethod def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit: @@ -52,7 +52,7 @@ def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit: """ c = cls.__new__(cls) - c._stim_circ = stim_circuit.flattened() + c._stim_circ = stim_circuit return c def append_from_stim_program_text(self, stim_program_text: str) -> None: @@ -63,7 +63,6 @@ def append_from_stim_program_text(self, stim_program_text: str) -> None: self._stim_circ.append_from_stim_program_text( shorthand_to_stim(stim_program_text) ) - self._stim_circ = self._stim_circ.flattened() @overload def append( @@ -158,8 +157,6 @@ def append( self._stim_circ.append(name=name, targets=targets, arg=arg, tag=tag) # type: ignore else: self._stim_circ.append(name=name) - if isinstance(name, stim.CircuitRepeatBlock): - self._stim_circ = self._stim_circ.flattened() @classmethod def from_file(cls, filename: str) -> Circuit: @@ -174,7 +171,7 @@ def from_file(cls, filename: str) -> Circuit: """ with open(filename, "r", encoding="utf-8") as f: stim_program_text = f.read() - stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)).flattened() + stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)) return cls.from_stim_program(stim_circ) def __repr__(self) -> str: @@ -212,7 +209,6 @@ def __add__(self, other: Circuit | stim.Circuit) -> Circuit: def __imul__(self, repetitions: int) -> Circuit: """Repeat this circuit in-place.""" self._stim_circ *= repetitions - self._stim_circ = self._stim_circ.flattened() return self def __mul__(self, repetitions: int) -> Circuit: @@ -227,7 +223,7 @@ def __rmul__(self, repetitions: int) -> Circuit: def __getitem__( self, index_or_slice: int, - ) -> stim.CircuitInstruction: ... + ) -> Union[stim.CircuitInstruction, stim.CircuitRepeatBlock]: ... @overload def __getitem__( @@ -371,27 +367,29 @@ def num_ticks( def pop( self, index: int = -1, - ) -> stim.CircuitInstruction: + ) -> Union[stim.CircuitInstruction, stim.CircuitRepeatBlock]: """Pops an operation from the end of the circuit, or at the given index. Args: index: Defaults to -1 (end of circuit). The index to pop from. Returns: - The popped instruction. + The popped instruction or repeat block. Raises: IndexError: The given index is outside the bounds of the circuit. """ - el = self._stim_circ.pop(index) - assert not isinstance(el, stim.CircuitRepeatBlock) - return el + return self._stim_circ.pop(index) def copy(self) -> Circuit: """Create a copy of this circuit.""" return Circuit.from_stim_program(self._stim_circ.copy()) + def flattened(self) -> Circuit: + """Return a copy of the circuit with all repeat blocks expanded.""" + return Circuit.from_stim_program(self._stim_circ.flattened()) + def without_noise(self) -> Circuit: """Return a copy of the circuit with all noise removed.""" return Circuit.from_stim_program(self._stim_circ.without_noise()) @@ -399,7 +397,7 @@ def without_noise(self) -> Circuit: def without_annotations(self) -> Circuit: """Return a copy of the circuit with all annotations removed.""" circ = stim.Circuit() - for instr in self._stim_circ: + for instr in self._stim_circ.flattened(): assert not isinstance(instr, stim.CircuitRepeatBlock) if instr.name in ["OBSERVABLE_INCLUDE", "DETECTOR"]: continue @@ -688,7 +686,7 @@ def cast_to_stim(self) -> stim.Circuit: def inverse(self) -> Circuit: """Return the inverse of the circuit.""" - inv_stim_raw = self._stim_circ.inverse() + inv_stim_raw = self._stim_circ.flattened().inverse() # Stim will only invert Clifford gates (and S[T] / S_DAG[T]) # Post-process to fix non-Clifford rotation gates stored as I[tag] diff --git a/src/tsim/utils/diagram.py b/src/tsim/utils/diagram.py index 34dc988..35b23c3 100644 --- a/src/tsim/utils/diagram.py +++ b/src/tsim/utils/diagram.py @@ -210,11 +210,24 @@ def tagged_gates_to_placeholder( of the I_ERROR placeholder gates to the actual gate names. """ - modified_circ = stim.Circuit() replace_dict: dict[float, GateLabel] = {} + modified_circ = _replace_tagged_gates(circuit, replace_dict) + return modified_circ, replace_dict + + +def _replace_tagged_gates( + circuit: stim.Circuit, + replace_dict: dict[float, GateLabel], +) -> stim.Circuit: + modified_circ = stim.Circuit() for instr in circuit: - assert not isinstance(instr, stim.CircuitRepeatBlock) + if isinstance(instr, stim.CircuitRepeatBlock): + modified_body = _replace_tagged_gates(instr.body_copy(), replace_dict) + modified_circ.append( + stim.CircuitRepeatBlock(instr.repeat_count, modified_body) + ) + continue # Handle T gates (S[T] and S_DAG[T]) if instr.tag == "T" and instr.name in ["S", "S_DAG"]: @@ -255,7 +268,7 @@ def tagged_gates_to_placeholder( continue modified_circ.append(instr) - return modified_circ, replace_dict + return modified_circ def render_svg( diff --git a/src/tsim/utils/encoder.py b/src/tsim/utils/encoder.py index f268bdc..b8dcbc9 100644 --- a/src/tsim/utils/encoder.py +++ b/src/tsim/utils/encoder.py @@ -27,8 +27,7 @@ def _transform_circuit( observables: list[list[int]] | None = None, ) -> stim.Circuit: """Expand and duplicate instructions with broadcast targets for encoding.""" - stim_circ = tsim.Circuit(program_text)._stim_circ - stim_circ = tsim.Circuit(program_text)._stim_circ + stim_circ = tsim.Circuit(program_text)._stim_circ.flattened() mod_circ = stim.Circuit() for instr in stim_circ: diff --git a/test/unit/core/test_parse.py b/test/unit/core/test_parse.py index d77f0ee..51e7696 100644 --- a/test/unit/core/test_parse.py +++ b/test/unit/core/test_parse.py @@ -127,6 +127,37 @@ def test_chain_multiple_qubits(self): assert_allclose(probs[4], 0.2, rtol=1e-5) # Third error +class TestParseWithRepeatBlocks: + """Tests for parsing circuits that contain REPEAT blocks.""" + + def test_parse_circuit_with_repeat_block(self): + """parse_stim_circuit should flatten repeat blocks transparently.""" + flat_circuit = stim.Circuit("H 0\nCNOT 0 1\nH 0\nCNOT 0 1\nH 0\nCNOT 0 1") + repeat_circuit = stim.Circuit("REPEAT 3 {\n H 0\n CNOT 0 1\n}") + + b_flat = parse_stim_circuit(flat_circuit) + b_repeat = parse_stim_circuit(repeat_circuit) + + assert len(b_flat.graph.vertices()) == len(b_repeat.graph.vertices()) + assert list(b_flat.graph.edges()) == list(b_repeat.graph.edges()) + + def test_parse_repeat_block_with_measurements(self): + """Repeat blocks containing measurements should parse correctly.""" + circuit = stim.Circuit("REPEAT 3 {\n H 0\n M 0\n}") + b = parse_stim_circuit(circuit) + assert len(b.rec) == 3 + + def test_parse_nested_repeat_blocks(self): + """Nested repeat blocks should be fully flattened by the parser.""" + circuit = stim.Circuit("REPEAT 2 {\n REPEAT 3 {\n H 0\n }\n}") + flat = stim.Circuit("H 0\nH 0\nH 0\nH 0\nH 0\nH 0") + + b_nested = parse_stim_circuit(circuit) + b_flat = parse_stim_circuit(flat) + + assert len(b_nested.graph.vertices()) == len(b_flat.graph.vertices()) + + class TestCorrelatedErrorState: """Tests for correlated error state management.""" diff --git a/test/unit/test_circuit.py b/test/unit/test_circuit.py index 452829b..612f8b0 100644 --- a/test/unit/test_circuit.py +++ b/test/unit/test_circuit.py @@ -227,7 +227,7 @@ def test_circuit_mul(): c1 = Circuit("H 0") c1_stim = c1._stim_circ.copy() c2 = c1 * 3 - assert c2._stim_circ == (c1_stim * 3).flattened() + assert c2._stim_circ == c1_stim * 3 def test_circuit_without_noise(): @@ -319,8 +319,7 @@ def test_circuit_imul(): """Test in-place multiplication.""" c = Circuit("H 0") c *= 3 - expected = Circuit("H 0\nH 0\nH 0") - assert c == expected + assert c.flattened() == Circuit("H 0\nH 0\nH 0") def test_circuit_imul_zero(): @@ -334,8 +333,7 @@ def test_circuit_rmul(): """Test right multiplication (n * circuit).""" c = Circuit("H 0") result = 3 * c - expected = Circuit("H 0\nH 0\nH 0") - assert result == expected + assert result.flattened() == Circuit("H 0\nH 0\nH 0") def test_circuit_getitem_int(): @@ -630,7 +628,7 @@ def test_diagram_pyzx_scale_horizontally( assert hasattr(g, "vertices") -def test_circuit_append(): +def test_append(): c = Circuit() c.append("T", [0, 1]) assert str(c) == "T 0 1" @@ -648,23 +646,91 @@ def test_circuit_append(): assert "U3(0.3, 0.24, 0.49) 0" in str(c) -def test_circuit_append_circuit_instruction(): +def test_append_circuit_instruction(): c = Circuit() c.append(stim.CircuitInstruction("H", [0])) assert str(c) == "H 0" -def test_circuit_append_circuit_repeat_block(): +def test_append_circuit_repeat_block(): c = Circuit() block = stim.CircuitRepeatBlock(3, stim.Circuit("H 0")) c.append(block) - # Should be flattened - assert str(c) == "H 0 0 0" + assert str(c.flattened()) == "H 0 0 0" + assert len(c) == 1 # single repeat block -def test_circuit_append_circuit(): +def test_append_circuit(): c = Circuit() sub_c = stim.Circuit("H 0\nCNOT 0 1") c.append(sub_c) assert "H 0" in str(c) assert "CX 0 1" in str(c) or "CNOT 0 1" in str(c) + + +def test_append_repetition_code(): + stim_c = stim.Circuit.generated("repetition_code:memory", distance=2, rounds=4) + c = Circuit() + for instr in stim_c: + c.append(instr) + + assert str(c.flattened()) == str(stim_c.flattened()) + assert str(c) == str(stim_c) + + +def _circuit_with_repeat_block() -> Circuit: + """Helper: build a Circuit that contains a REPEAT block.""" + c = Circuit("H 0") + block = stim.CircuitRepeatBlock(5, stim.Circuit("CNOT 0 1\nTICK")) + c.append(block) + c.append("M", [0, 1]) + return c + + +def test_mul_preserves_repeat_block(): + """c * n should wrap in a repeat block, not flatten.""" + c = Circuit("H 0\nCNOT 0 1") + c2 = c * 4 + assert c2._stim_circ == c._stim_circ * 4 + # flattened form should equal the naive expansion + assert c2.flattened() == c + c + c + c + + +def test_imul_preserves_repeat_block(): + c = Circuit("H 0\nCNOT 0 1") + flat_4x = c + c + c + c + c *= 4 + assert c.flattened() == flat_4x + + +def test_getitem_repeat_block(): + """Indexing into a circuit may return a CircuitRepeatBlock.""" + c = _circuit_with_repeat_block() + item = c[1] + assert isinstance(item, stim.CircuitRepeatBlock) + assert item.repeat_count == 5 + + +def test_getitem_slice_with_repeat_block(): + c = _circuit_with_repeat_block() + sliced = c[0:2] + assert isinstance(sliced, Circuit) + assert len(sliced) == 2 + + +def test_pop_repeat_block(): + c = Circuit() + block = stim.CircuitRepeatBlock(3, stim.Circuit("X 0")) + c.append(block) + popped = c.pop() + assert isinstance(popped, stim.CircuitRepeatBlock) + assert popped.repeat_count == 3 + assert len(c) == 0 + + +def test_copy_preserves_repeat_block(): + c = _circuit_with_repeat_block() + c2 = c.copy() + assert c == c2 + assert c is not c2 + assert str(c) == str(c2) diff --git a/test/unit/utils/test_diagram.py b/test/unit/utils/test_diagram.py index 70349c8..6b154ee 100644 --- a/test/unit/utils/test_diagram.py +++ b/test/unit/utils/test_diagram.py @@ -87,3 +87,16 @@ def test_render_svg_labels_all_gates( # U3 label assert '3' in html + + +def test_diagram_repeat_block(): + c = Circuit( + """ + T 0 + REPEAT 100 { + T 0 + } + """ + ) + diagram = c.diagram("timeline-svg", height=150) + assert "REP100" in str(diagram) From 2930a23a2e506779140057384102ff2dfaec6b10 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Fri, 13 Mar 2026 13:24:31 -0400 Subject: [PATCH 2/3] fix(circuit): ensure circuit inverse correctly handles repeat blocks --- src/tsim/circuit.py | 66 ++++++++++++++++++++++----------------- test/unit/test_circuit.py | 12 +++++++ 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index d4d60a3..9341976 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -52,7 +52,7 @@ def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit: """ c = cls.__new__(cls) - c._stim_circ = stim_circuit + c._stim_circ = stim_circuit.copy() return c def append_from_stim_program_text(self, stim_program_text: str) -> None: @@ -686,35 +686,43 @@ def cast_to_stim(self) -> stim.Circuit: def inverse(self) -> Circuit: """Return the inverse of the circuit.""" - inv_stim_raw = self._stim_circ.flattened().inverse() - # Stim will only invert Clifford gates (and S[T] / S_DAG[T]) - # Post-process to fix non-Clifford rotation gates stored as I[tag] - inv_stim = stim.Circuit() - for instr in inv_stim_raw: - assert not isinstance(instr, stim.CircuitRepeatBlock) - name = instr.name - tag = instr.tag - targets = [t.value for t in instr.targets_copy()] - args = instr.gate_args_copy() - - if name == "I" and tag: - result = parse_parametric_tag(tag) - if result is not None: - gate_name, params = result - if gate_name == "U3": - # U3(θ, φ, λ)⁻¹ = U3(-θ, -λ, -φ) - theta = float(-params["theta"]) - phi = float(-params["lambda"]) - lam = float(-params["phi"]) - new_tag = f"U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)" - else: - theta = float(-params["theta"]) - new_tag = f"{gate_name}(theta={theta}*pi)" - inv_stim.append("I", targets, args, tag=new_tag) + def fix_tags(circuit: stim.Circuit) -> stim.Circuit: + # Stim only inverts Clifford gates (and S[T] / S_DAG[T]). + # Non-Clifford rotations stored as I[tag] need their parameters + # negated (and reordered for U3). + result = stim.Circuit() + for instr in circuit: + if isinstance(instr, stim.CircuitRepeatBlock): + fixed = fix_tags(instr.body_copy()) + result.append(stim.CircuitRepeatBlock(instr.repeat_count, fixed)) continue - # All other instructions are correct from stim's inverse - inv_stim.append(instr) - + name = instr.name + tag = instr.tag + targets = [t.value for t in instr.targets_copy()] + args = instr.gate_args_copy() + + if name == "I" and tag: + parsed = parse_parametric_tag(tag) + if parsed is not None: + gate_name, params = parsed + if gate_name == "U3": + # U3(θ, φ, λ)⁻¹ = U3(-θ, -λ, -φ) + theta = float(-params["theta"]) + phi = float(-params["lambda"]) + lam = float(-params["phi"]) + new_tag = ( + f"U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)" + ) + else: + theta = float(-params["theta"]) + new_tag = f"{gate_name}(theta={theta}*pi)" + result.append("I", targets, args, tag=new_tag) + continue + + result.append(instr) + return result + + inv_stim = fix_tags(self._stim_circ.inverse()) return Circuit.from_stim_program(inv_stim) diff --git a/test/unit/test_circuit.py b/test/unit/test_circuit.py index 612f8b0..00615ee 100644 --- a/test/unit/test_circuit.py +++ b/test/unit/test_circuit.py @@ -569,6 +569,18 @@ def test_inverse_mixed_circuit(): assert unitaries_equal_up_to_global_phase(combined, np.eye(combined.shape[0])) +def test_inverse_with_repeat_block(): + c = Circuit("H 0\nT 0\nR_Z(0.22) 0\nCNOT 0 1") + c_repeat = c * 3 + c_inv = c_repeat.inverse() + # inverse should preserve repeat structure, not flatten + assert len(c_inv) == len(c_repeat) + assert isinstance(c_inv[0], stim.CircuitRepeatBlock) + assert c_inv.flattened() == c_repeat.flattened().inverse() + combined = (c_repeat + c_inv).to_matrix() + assert unitaries_equal_up_to_global_phase(combined, np.eye(combined.shape[0])) + + def test_diagram_timeline_svg(): c = Circuit("H 0\nCNOT 0 1\nM 0 1") diagram = c.diagram(type="timeline-svg") From 42a58fdb9f3e40284f44d4bd15cc35d7b972a6ca Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Fri, 13 Mar 2026 13:30:00 -0400 Subject: [PATCH 3/3] refactor(circuit): enhance `without_annotations` to handle repeat blocks --- src/tsim/circuit.py | 23 +++++++++++++++-------- test/unit/test_circuit.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index 9341976..90519f4 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -395,14 +395,21 @@ def without_noise(self) -> Circuit: return Circuit.from_stim_program(self._stim_circ.without_noise()) def without_annotations(self) -> Circuit: - """Return a copy of the circuit with all annotations removed.""" - circ = stim.Circuit() - for instr in self._stim_circ.flattened(): - assert not isinstance(instr, stim.CircuitRepeatBlock) - if instr.name in ["OBSERVABLE_INCLUDE", "DETECTOR"]: - continue - circ.append(instr) - return Circuit.from_stim_program(circ) + """Return a copy of the circuit with all detector and observable annotations removed.""" + + def strip(circuit: stim.Circuit) -> stim.Circuit: + result = stim.Circuit() + for instr in circuit: + if isinstance(instr, stim.CircuitRepeatBlock): + stripped = strip(instr.body_copy()) + result.append(stim.CircuitRepeatBlock(instr.repeat_count, stripped)) + continue + if instr.name in ["OBSERVABLE_INCLUDE", "DETECTOR"]: + continue + result.append(instr) + return result + + return Circuit.from_stim_program(strip(self._stim_circ)) def detector_error_model( self, diff --git a/test/unit/test_circuit.py b/test/unit/test_circuit.py index 00615ee..13114eb 100644 --- a/test/unit/test_circuit.py +++ b/test/unit/test_circuit.py @@ -242,6 +242,24 @@ def test_circuit_without_annotations(): assert c_clean._stim_circ == stim.Circuit("H 0\nM 0") +def test_without_annotations_repeat_block(): + c = Circuit("H 0") + block = stim.CircuitRepeatBlock( + 3, stim.Circuit("CNOT 0 1\nM 0\nDETECTOR rec[-1]\nM 0") + ) + c.append(block) + c.append("OBSERVABLE_INCLUDE", [stim.target_rec(-1)], 0) + + c_clean = c.without_annotations() + # structure should be preserved + assert len(c_clean) == 2 + inst = c_clean[1] + assert isinstance(inst, stim.CircuitRepeatBlock) + assert inst.repeat_count == 3 + # annotations should be stripped inside the repeat block too + assert c_clean.flattened() == c.flattened().without_annotations() + + def test_circuit_eq(): c1 = Circuit("H 0") c2 = Circuit("H 0")