diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index d22c118a8..d198be69e 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -139,6 +139,65 @@ def test_random_symbol_circuit_resolver_batch_shapes_and_types( isinstance(value, float) for value in resolver.param_dict.values())) + def test_random_circuit_resolver_batch_channels_present(self): + """Confirm channel ops appear in circuits when include_channels=True.""" + qubits = cirq.GridQubit.rect(1, 3) + channel_types = tuple(type(c) for c in util.get_supported_channels()) + + circuits, _ = util.random_circuit_resolver_batch(qubits, + batch_size=5, + n_moments=20, + include_channels=True) + + def has_channel(circuit): + return any( + isinstance(op.gate, channel_types) + for moment in circuit + for op in moment.operations) + + self.assertTrue( + any(has_channel(c) for c in circuits), + "Expected at least one channel operation across batch when " + "include_channels=True") + # Verify circuits containing channels are correctly serializable + # round-trip. + serialized = util.convert_to_tensor(circuits, + deterministic_proto_serialize=True) + deserialized = util.from_tensor(serialized) + self.assertAllEqual( + serialized, + util.convert_to_tensor(deserialized, + deterministic_proto_serialize=True)) + + def test_random_symbol_circuit_resolver_batch_channels_present(self): + """Confirm channel ops appear in circuits when include_channels=True.""" + qubits = cirq.GridQubit.rect(1, 3) + symbols = ['alpha', 'beta'] + channel_types = tuple(type(c) for c in util.get_supported_channels()) + + circuits, _ = util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size=5, n_moments=20, include_channels=True) + + def has_channel(circuit): + return any( + isinstance(op.gate, channel_types) + for moment in circuit + for op in moment.operations) + + self.assertTrue( + any(has_channel(c) for c in circuits), + "Expected at least one channel operation across batch when " + "include_channels=True") + # Verify circuits containing channels are correctly serializable + # round-trip. + serialized = util.convert_to_tensor(circuits, + deterministic_proto_serialize=True) + deserialized = util.from_tensor(serialized) + self.assertAllEqual( + serialized, + util.convert_to_tensor(deserialized, + deterministic_proto_serialize=True)) + @parameterized.parameters(_items_to_tensorize()) def test_convert_to_tensor(self, item): """Test that the convert_to_tensor function works correctly by manually