Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 63 additions & 50 deletions src/tsim/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, stim_program_text: str = ""):
empty circuit.

"""
self._stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)).flattened()
self._stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text))

@classmethod
def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit:
Expand All @@ -52,7 +52,7 @@ def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit:

"""
c = cls.__new__(cls)
c._stim_circ = stim_circuit.flattened()
c._stim_circ = stim_circuit.copy()
return c

def append_from_stim_program_text(self, stim_program_text: str) -> None:
Expand All @@ -63,7 +63,6 @@ def append_from_stim_program_text(self, stim_program_text: str) -> None:
self._stim_circ.append_from_stim_program_text(
shorthand_to_stim(stim_program_text)
)
self._stim_circ = self._stim_circ.flattened()

@overload
def append(
Expand Down Expand Up @@ -158,8 +157,6 @@ def append(
self._stim_circ.append(name=name, targets=targets, arg=arg, tag=tag) # type: ignore
else:
self._stim_circ.append(name=name)
if isinstance(name, stim.CircuitRepeatBlock):
self._stim_circ = self._stim_circ.flattened()

@classmethod
def from_file(cls, filename: str) -> Circuit:
Expand All @@ -174,7 +171,7 @@ def from_file(cls, filename: str) -> Circuit:
"""
with open(filename, "r", encoding="utf-8") as f:
stim_program_text = f.read()
stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text)).flattened()
stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text))
return cls.from_stim_program(stim_circ)

def __repr__(self) -> str:
Expand Down Expand Up @@ -212,7 +209,6 @@ def __add__(self, other: Circuit | stim.Circuit) -> Circuit:
def __imul__(self, repetitions: int) -> Circuit:
"""Repeat this circuit in-place."""
self._stim_circ *= repetitions
self._stim_circ = self._stim_circ.flattened()
return self

def __mul__(self, repetitions: int) -> Circuit:
Expand All @@ -227,7 +223,7 @@ def __rmul__(self, repetitions: int) -> Circuit:
def __getitem__(
self,
index_or_slice: int,
) -> stim.CircuitInstruction: ...
) -> Union[stim.CircuitInstruction, stim.CircuitRepeatBlock]: ...

@overload
def __getitem__(
Expand Down Expand Up @@ -371,40 +367,49 @@ def num_ticks(
def pop(
self,
index: int = -1,
) -> stim.CircuitInstruction:
) -> Union[stim.CircuitInstruction, stim.CircuitRepeatBlock]:
"""Pops an operation from the end of the circuit, or at the given index.

Args:
index: Defaults to -1 (end of circuit). The index to pop from.

Returns:
The popped instruction.
The popped instruction or repeat block.

Raises:
IndexError: The given index is outside the bounds of the circuit.

"""
el = self._stim_circ.pop(index)
assert not isinstance(el, stim.CircuitRepeatBlock)
return el
return self._stim_circ.pop(index)

def copy(self) -> Circuit:
"""Create a copy of this circuit."""
return Circuit.from_stim_program(self._stim_circ.copy())

def flattened(self) -> Circuit:
"""Return a copy of the circuit with all repeat blocks expanded."""
return Circuit.from_stim_program(self._stim_circ.flattened())

def without_noise(self) -> Circuit:
"""Return a copy of the circuit with all noise removed."""
return Circuit.from_stim_program(self._stim_circ.without_noise())

def without_annotations(self) -> Circuit:
"""Return a copy of the circuit with all annotations removed."""
circ = stim.Circuit()
for instr in self._stim_circ:
assert not isinstance(instr, stim.CircuitRepeatBlock)
if instr.name in ["OBSERVABLE_INCLUDE", "DETECTOR"]:
continue
circ.append(instr)
return Circuit.from_stim_program(circ)
"""Return a copy of the circuit with all detector and observable annotations removed."""

def strip(circuit: stim.Circuit) -> stim.Circuit:
result = stim.Circuit()
for instr in circuit:
if isinstance(instr, stim.CircuitRepeatBlock):
stripped = strip(instr.body_copy())
result.append(stim.CircuitRepeatBlock(instr.repeat_count, stripped))
continue
if instr.name in ["OBSERVABLE_INCLUDE", "DETECTOR"]:
continue
result.append(instr)
return result

return Circuit.from_stim_program(strip(self._stim_circ))

def detector_error_model(
self,
Expand Down Expand Up @@ -688,35 +693,43 @@ def cast_to_stim(self) -> stim.Circuit:

def inverse(self) -> Circuit:
"""Return the inverse of the circuit."""
inv_stim_raw = self._stim_circ.inverse()

# Stim will only invert Clifford gates (and S[T] / S_DAG[T])
# Post-process to fix non-Clifford rotation gates stored as I[tag]
inv_stim = stim.Circuit()
for instr in inv_stim_raw:
assert not isinstance(instr, stim.CircuitRepeatBlock)
name = instr.name
tag = instr.tag
targets = [t.value for t in instr.targets_copy()]
args = instr.gate_args_copy()

if name == "I" and tag:
result = parse_parametric_tag(tag)
if result is not None:
gate_name, params = result
if gate_name == "U3":
# U3(θ, φ, λ)⁻¹ = U3(-θ, -λ, -φ)
theta = float(-params["theta"])
phi = float(-params["lambda"])
lam = float(-params["phi"])
new_tag = f"U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)"
else:
theta = float(-params["theta"])
new_tag = f"{gate_name}(theta={theta}*pi)"
inv_stim.append("I", targets, args, tag=new_tag)
continue

# All other instructions are correct from stim's inverse
inv_stim.append(instr)
def fix_tags(circuit: stim.Circuit) -> stim.Circuit:
# Stim only inverts Clifford gates (and S[T] / S_DAG[T]).
# Non-Clifford rotations stored as I[tag] need their parameters
# negated (and reordered for U3).
result = stim.Circuit()
for instr in circuit:
if isinstance(instr, stim.CircuitRepeatBlock):
fixed = fix_tags(instr.body_copy())
result.append(stim.CircuitRepeatBlock(instr.repeat_count, fixed))
continue

name = instr.name
tag = instr.tag
targets = [t.value for t in instr.targets_copy()]
args = instr.gate_args_copy()

if name == "I" and tag:
parsed = parse_parametric_tag(tag)
if parsed is not None:
gate_name, params = parsed
if gate_name == "U3":
# U3(θ, φ, λ)⁻¹ = U3(-θ, -λ, -φ)
theta = float(-params["theta"])
phi = float(-params["lambda"])
lam = float(-params["phi"])
new_tag = (
f"U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)"
)
else:
theta = float(-params["theta"])
new_tag = f"{gate_name}(theta={theta}*pi)"
result.append("I", targets, args, tag=new_tag)
continue

result.append(instr)
return result

inv_stim = fix_tags(self._stim_circ.inverse())
return Circuit.from_stim_program(inv_stim)
19 changes: 16 additions & 3 deletions src/tsim/utils/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,24 @@ def tagged_gates_to_placeholder(
of the I_ERROR placeholder gates to the actual gate names.

"""
modified_circ = stim.Circuit()
replace_dict: dict[float, GateLabel] = {}
modified_circ = _replace_tagged_gates(circuit, replace_dict)
return modified_circ, replace_dict


def _replace_tagged_gates(
circuit: stim.Circuit,
replace_dict: dict[float, GateLabel],
) -> stim.Circuit:
modified_circ = stim.Circuit()

for instr in circuit:
assert not isinstance(instr, stim.CircuitRepeatBlock)
if isinstance(instr, stim.CircuitRepeatBlock):
modified_body = _replace_tagged_gates(instr.body_copy(), replace_dict)
modified_circ.append(
stim.CircuitRepeatBlock(instr.repeat_count, modified_body)
)
continue

# Handle T gates (S[T] and S_DAG[T])
if instr.tag == "T" and instr.name in ["S", "S_DAG"]:
Expand Down Expand Up @@ -255,7 +268,7 @@ def tagged_gates_to_placeholder(
continue

modified_circ.append(instr)
return modified_circ, replace_dict
return modified_circ


def render_svg(
Expand Down
3 changes: 1 addition & 2 deletions src/tsim/utils/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def _transform_circuit(
observables: list[list[int]] | None = None,
) -> stim.Circuit:
"""Expand and duplicate instructions with broadcast targets for encoding."""
stim_circ = tsim.Circuit(program_text)._stim_circ
stim_circ = tsim.Circuit(program_text)._stim_circ
stim_circ = tsim.Circuit(program_text)._stim_circ.flattened()
mod_circ = stim.Circuit()

for instr in stim_circ:
Expand Down
31 changes: 31 additions & 0 deletions test/unit/core/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,37 @@ def test_chain_multiple_qubits(self):
assert_allclose(probs[4], 0.2, rtol=1e-5) # Third error


class TestParseWithRepeatBlocks:
"""Tests for parsing circuits that contain REPEAT blocks."""

def test_parse_circuit_with_repeat_block(self):
"""parse_stim_circuit should flatten repeat blocks transparently."""
flat_circuit = stim.Circuit("H 0\nCNOT 0 1\nH 0\nCNOT 0 1\nH 0\nCNOT 0 1")
repeat_circuit = stim.Circuit("REPEAT 3 {\n H 0\n CNOT 0 1\n}")

b_flat = parse_stim_circuit(flat_circuit)
b_repeat = parse_stim_circuit(repeat_circuit)

assert len(b_flat.graph.vertices()) == len(b_repeat.graph.vertices())
assert list(b_flat.graph.edges()) == list(b_repeat.graph.edges())

def test_parse_repeat_block_with_measurements(self):
"""Repeat blocks containing measurements should parse correctly."""
circuit = stim.Circuit("REPEAT 3 {\n H 0\n M 0\n}")
b = parse_stim_circuit(circuit)
assert len(b.rec) == 3

def test_parse_nested_repeat_blocks(self):
"""Nested repeat blocks should be fully flattened by the parser."""
circuit = stim.Circuit("REPEAT 2 {\n REPEAT 3 {\n H 0\n }\n}")
flat = stim.Circuit("H 0\nH 0\nH 0\nH 0\nH 0\nH 0")

b_nested = parse_stim_circuit(circuit)
b_flat = parse_stim_circuit(flat)

assert len(b_nested.graph.vertices()) == len(b_flat.graph.vertices())


class TestCorrelatedErrorState:
"""Tests for correlated error state management."""

Expand Down
Loading
Loading