-
Notifications
You must be signed in to change notification settings - Fork 31
feat: Classical control flow with mid circuit measurements #347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 39 commits
259b318
253a450
5d38025
9729a19
90cc443
79ccfe9
0df15af
eb03649
b4c2ac8
4bf80a1
6f1ec22
18e1dc2
e3c8982
115612f
4456be9
0ae913b
61f5a46
4a859cb
96b9b2e
cabffb7
64ab4e9
6b6d207
ca29dc2
1142cb2
5d05c01
1ac7b48
325c353
96fa8ab
4549801
98a7f4a
29f5899
a3abdc4
87319ce
77283e2
8bca38f
f927119
a68744d
08f7329
8f99f71
77fe95b
2bf6c35
26babd4
7ae10d1
5a7810e
c2a2960
d8e88e4
56dab41
bfd2175
fe12d6e
189855f
f289b82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
| *.swp | ||
| *.idea | ||
| *.iml | ||
| .vscode/ | ||
| .kiro/ | ||
| build_files.tar.gz | ||
|
|
||
| .ycm_extra_conf.py | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1281,6 +1281,106 @@ def gate_type(self) -> str: | |
| return "gphase" | ||
|
|
||
|
|
||
| class Measure(GateOperation): | ||
| """ | ||
| Measurement operation that projects the state to a specific outcome. | ||
|
|
||
| This is used in branched simulation to apply measurement projections | ||
| when recalculating states from instruction sequences. | ||
| """ | ||
|
|
||
| def __init__(self, targets: Sequence[int], result: int = -1): | ||
| super().__init__(targets=targets) | ||
| self.result = result # The measurement outcome (0 or 1) | ||
|
|
||
| @property | ||
| def _base_matrix(self) -> np.ndarray: | ||
| """ | ||
| Return the projection matrix for the measurement outcome. | ||
| If result is -1 (unset), return identity (no projection). | ||
| """ | ||
| if self.result == -1: | ||
| return np.eye(2) | ||
| elif self.result == 0: | ||
| # Project to |0⟩⟨0| | ||
| return np.array([[1, 0], [0, 0]], dtype=complex) | ||
| elif self.result == 1: | ||
| # Project to |1⟩⟨1| | ||
| return np.array([[0, 0], [0, 1]], dtype=complex) | ||
| else: | ||
| return np.eye(2) | ||
|
|
||
| def apply(self, state: np.ndarray) -> np.ndarray: | ||
| if self.result == -1: | ||
| return state | ||
|
|
||
| # Apply projection matrix | ||
| projected_state = state.copy() | ||
|
|
||
| # For single qubit measurement, we need to project the appropriate amplitudes | ||
| if len(self._targets) == 1: | ||
| qubit_idx = self._targets[0] | ||
| n_qubits = int(np.log2(len(state))) | ||
|
|
||
| # Create mask for the target qubit | ||
| mask = 1 << (n_qubits - qubit_idx - 1) # Big-endian indexing | ||
|
|
||
| # Zero out amplitudes that don't match the measurement result | ||
| for i in range(len(projected_state)): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider njitting this as a follow up |
||
| qubit_value = (i & mask) >> (n_qubits - qubit_idx - 1) | ||
| if qubit_value != self.result: | ||
| projected_state[i] = 0 | ||
|
|
||
| # Normalize the state | ||
| norm = np.linalg.norm(projected_state) | ||
| if norm > 0: | ||
| projected_state /= norm | ||
|
|
||
| return projected_state | ||
|
|
||
|
|
||
| class Reset(GateOperation): | ||
| """ | ||
| Reset operation that sets desired target to 0 | ||
| """ | ||
|
|
||
| def __init__(self, targets: Sequence[int]): | ||
| super().__init__(targets=targets) | ||
|
|
||
| @property | ||
| def _base_matrix(self) -> np.ndarray: | ||
| raise NotImplementedError("Reset does not have a matrix implementation") | ||
|
|
||
| def apply(self, state: np.ndarray) -> np.ndarray: | ||
|
|
||
| # For single qubit measurement, we need to project the appropriate amplitudes | ||
| if len(self._targets) == 1: | ||
| qubit_idx = self._targets[0] | ||
| n_qubits = int(np.log2(len(state))) | ||
|
|
||
| # Create mask for the target qubit | ||
| mask = 1 << (n_qubits - qubit_idx - 1) # Big-endian indexing | ||
|
|
||
| for i in range(len(state)): | ||
| # Check if the qubit is in state 1 | ||
| qubit_value = (i & mask) >> (n_qubits - qubit_idx - 1) | ||
| if qubit_value == 1: | ||
| zero_index = i & ~mask | ||
|
|
||
| # Transfer the amplitude (with proper scaling) | ||
| state[zero_index] += state[i] | ||
|
rmshaffer marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Set the original amplitude to zero | ||
| state[i] = 0 | ||
|
|
||
| # Normalize the state | ||
| norm = np.linalg.norm(state) | ||
| if norm > 0: | ||
| state /= norm | ||
|
|
||
| return state | ||
|
|
||
|
|
||
| BRAKET_GATES = { | ||
| "i": Identity, | ||
| "h": Hadamard, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,7 @@ | |
| from .circuit import Circuit | ||
| from .parser.openqasm_ast import ( | ||
| AccessControl, | ||
| AliasStatement, | ||
| ArrayLiteral, | ||
| ArrayReferenceType, | ||
| ArrayType, | ||
|
|
@@ -64,15 +65,20 @@ | |
| BitstringLiteral, | ||
| BitType, | ||
| BooleanLiteral, | ||
| BoolType, | ||
| Box, | ||
| BranchingStatement, | ||
| BreakStatement, | ||
| Cast, | ||
| ClassicalArgument, | ||
| ClassicalAssignment, | ||
| ClassicalDeclaration, | ||
| Concatenation, | ||
| ConstantDeclaration, | ||
| ContinueStatement, | ||
| DiscreteSet, | ||
| FloatLiteral, | ||
| FloatType, | ||
| ForInLoop, | ||
| FunctionCall, | ||
| GateModifierName, | ||
|
|
@@ -81,6 +87,7 @@ | |
| IndexedIdentifier, | ||
| IndexExpression, | ||
| IntegerLiteral, | ||
| IntType, | ||
| IODeclaration, | ||
| IOKeyword, | ||
| Pragma, | ||
|
|
@@ -101,11 +108,12 @@ | |
| SizeOf, | ||
| SubroutineDefinition, | ||
| SymbolLiteral, | ||
| UintType, | ||
| UnaryExpression, | ||
| WhileLoop, | ||
| ) | ||
| from .parser.openqasm_parser import parse | ||
| from .program_context import AbstractProgramContext, ProgramContext | ||
| from .program_context import AbstractProgramContext, ProgramContext, _BreakSignal, _ContinueSignal | ||
|
|
||
|
|
||
| class Interpreter: | ||
|
|
@@ -128,6 +136,8 @@ def __init__( | |
| ): | ||
| # context keeps track of all state | ||
| self.context = context or ProgramContext() | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.set_visitor(self.visit) | ||
| self.logger = logger or getLogger(__name__) | ||
| self._uses_advanced_language_features = False | ||
| self._warn_advanced_features = warn_advanced_features | ||
|
|
@@ -196,6 +206,12 @@ def _(self, node: ClassicalDeclaration) -> None: | |
| init_value = create_empty_array(node_type.dimensions) | ||
| elif isinstance(node_type, BitType) and node_type.size: | ||
| init_value = create_empty_array([node_type.size]) | ||
| elif isinstance(node_type, (IntType, UintType)): | ||
| init_value = IntegerLiteral(value=0) | ||
| elif isinstance(node_type, FloatType): | ||
| init_value = FloatLiteral(value=0.0) | ||
| elif isinstance(node_type, BoolType): | ||
| init_value = BooleanLiteral(value=False) | ||
| else: | ||
| init_value = None | ||
| self.context.declare_variable(node.identifier.name, node_type, init_value) | ||
|
|
@@ -265,7 +281,7 @@ def _(self, node: QubitDeclaration) -> None: | |
|
|
||
| @visit.register | ||
| def _(self, node: QuantumReset) -> None: | ||
| raise NotImplementedError("Reset not supported") | ||
| self.context.add_reset(list(self.context.get_qubits(self.visit(node.qubits)))) | ||
|
|
||
| @visit.register | ||
| def _(self, node: QuantumBarrier) -> None: | ||
|
|
@@ -528,7 +544,8 @@ def _(self, node: QuantumMeasurementStatement) -> None: | |
| self._uses_advanced_language_features = True | ||
| targets.extend(convert_range_def_to_range(self.visit(elem))) | ||
| case _: | ||
| targets.append(elem.value) | ||
| resolved = self.visit(elem) if isinstance(elem, Identifier) else elem | ||
| targets.append(resolved.value) | ||
|
|
||
| if not len(targets): | ||
| targets = None | ||
|
|
@@ -537,10 +554,26 @@ def _(self, node: QuantumMeasurementStatement) -> None: | |
| raise ValueError( | ||
| f"Number of qubits ({len(qubits)}) does not match number of provided classical targets ({len(targets)})" | ||
| ) | ||
| self.context.add_measure(qubits, targets) | ||
| if node.target and self.context.supports_midcircuit_measurement: | ||
| self.context.add_measure(qubits, targets, classical_destination=node.target) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point in the visitor, there is no way to know whether the measurement is being used as a MCM right?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. correct, But I do agree it's a bit confusing that the logic relys on the presence of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code path that relies on the presence of classical_destination seems fragile and hard to maintain. Is there a more explicit way? |
||
| else: | ||
| self.context.add_measure(qubits, targets) | ||
|
|
||
| @visit.register | ||
| def _(self, node: ClassicalAssignment) -> None: | ||
| is_branched = getattr(self.context, "_is_branched", False) | ||
| if not is_branched or len(self.context._active_path_indices) <= 1: | ||
| self._execute_classical_assignment(node) | ||
| else: | ||
| # When multiple paths are active, evaluate the rvalue per-path | ||
| # so that expressions like ``y = x`` read from the correct path. | ||
| saved_active = list(self.context._active_path_indices) | ||
| for path_idx in saved_active: | ||
| self.context._active_path_indices = [path_idx] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is unclear why the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is to handle the edge case that I added: It does feels hacky this way.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 to Aniket comment. We should avoid overwriting. Changing the logic that the program context always only get value from the first active branch may be the key to avoid that |
||
| self._execute_classical_assignment(deepcopy(node)) | ||
| self.context._active_path_indices = saved_active | ||
|
|
||
| def _execute_classical_assignment(self, node: ClassicalAssignment) -> None: | ||
| lvalue_name = get_identifier_name(node.lvalue) | ||
| if self.context.get_const(lvalue_name): | ||
| raise TypeError(f"Cannot update const value {lvalue_name}") | ||
|
|
@@ -565,29 +598,76 @@ def _(self, node: BitstringLiteral) -> ArrayLiteral: | |
| @visit.register | ||
| def _(self, node: BranchingStatement) -> None: | ||
| self._uses_advanced_language_features = True | ||
| condition = cast_to(BooleanLiteral, self.visit(node.condition)) | ||
| for statement in node.if_block if condition.value else node.else_block: | ||
| self.visit(statement) | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_branching_statement(node) | ||
|
yitchen-tim marked this conversation as resolved.
|
||
| else: | ||
| condition = cast_to(BooleanLiteral, self.visit(node.condition)) | ||
| if condition.value: | ||
| self.visit(node.if_block) | ||
| elif node.else_block: | ||
| self.visit(node.else_block) | ||
|
|
||
| @visit.register | ||
| def _(self, node: ForInLoop) -> None: | ||
| self._uses_advanced_language_features = True | ||
| index = self.visit(node.set_declaration) | ||
| if isinstance(index, RangeDefinition): | ||
| index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)] | ||
| # DiscreteSet | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_for_loop(node) | ||
| else: | ||
| index_values = index.values | ||
| for i in index_values: | ||
| with self.context.enter_scope(): | ||
| self.context.declare_variable(node.identifier.name, node.type, i) | ||
| self.visit(deepcopy(node.block)) | ||
| index = self.visit(node.set_declaration) | ||
| if isinstance(index, RangeDefinition): | ||
| index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)] | ||
| else: | ||
| index_values = index.values | ||
|
|
||
| loop_var_name = node.identifier.name | ||
| for i in index_values: | ||
| with self.context.enter_scope(): | ||
| self.context.declare_variable(loop_var_name, node.type, i) | ||
| try: | ||
| self.visit(deepcopy(node.block)) | ||
| except _BreakSignal: | ||
| break | ||
| except _ContinueSignal: | ||
| continue | ||
|
|
||
| @visit.register | ||
| def _(self, node: WhileLoop) -> None: | ||
| self._uses_advanced_language_features = True | ||
| while cast_to(BooleanLiteral, self.visit(deepcopy(node.while_condition))).value: | ||
| self.visit(deepcopy(node.block)) | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_while_loop(node) | ||
| else: | ||
| while cast_to(BooleanLiteral, self.visit(node.while_condition)).value: | ||
| try: | ||
| self.visit(deepcopy(node.block)) | ||
| except _BreakSignal: | ||
| break | ||
| except _ContinueSignal: | ||
| continue | ||
|
|
||
| @visit.register | ||
| def _(self, node: BreakStatement) -> None: | ||
| raise _BreakSignal() | ||
|
|
||
| @visit.register | ||
| def _(self, node: ContinueStatement) -> None: | ||
| raise _ContinueSignal() | ||
|
|
||
| @visit.register | ||
| def _(self, node: AliasStatement) -> None: | ||
| """Handle alias statements (let q1 = q, let combined = q1 ++ q2).""" | ||
|
speller26 marked this conversation as resolved.
|
||
| alias_name = node.target.name | ||
| if isinstance(node.value, Identifier): | ||
| # Simple alias: let q1 = q | ||
| self.context.qubit_mapping[alias_name] = self.context.get_qubits(node.value) | ||
| self.context.declare_qubit_alias(alias_name, node.value) | ||
| elif isinstance(node.value, Concatenation): | ||
| # Concatenation alias: let combined = q1 ++ q2 | ||
| lhs_qubits = tuple(self.context.get_qubits(node.value.lhs)) | ||
| rhs_qubits = tuple(self.context.get_qubits(node.value.rhs)) | ||
| self.context.qubit_mapping[alias_name] = lhs_qubits + rhs_qubits | ||
| self.context.declare_qubit_alias(alias_name, Identifier(alias_name)) | ||
| else: | ||
| raise NotImplementedError(f"Alias with {type(node.value).__name__} is not supported") | ||
|
|
||
| @visit.register | ||
| def _(self, node: Include) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't this be done inplace?