diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index f1484b0ce..4f6909d94 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -9,8 +9,7 @@ from kafka.metrics.measurable import AnonMeasurable from kafka.metrics.stats import Avg, Max, Rate from kafka.producer.transaction_manager import ProducerIdAndEpoch -from kafka.protocol.init_producer_id import InitProducerIdRequest -from kafka.protocol.produce import ProduceRequest +from kafka.protocol.new.producer import InitProducerIdRequest, ProduceRequest from kafka.structs import TopicPartition from kafka.version import __version__ diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index 6ebf748e9..dd20cd9b4 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -6,12 +6,11 @@ import threading import kafka.errors as Errors -from kafka.protocol.add_offsets_to_txn import AddOffsetsToTxnRequest -from kafka.protocol.add_partitions_to_txn import AddPartitionsToTxnRequest -from kafka.protocol.end_txn import EndTxnRequest -from kafka.protocol.find_coordinator import FindCoordinatorRequest -from kafka.protocol.init_producer_id import InitProducerIdRequest -from kafka.protocol.txn_offset_commit import TxnOffsetCommitRequest +from kafka.protocol.new.metadata import FindCoordinatorRequest +from kafka.protocol.new.producer import ( + AddOffsetsToTxnRequest, AddPartitionsToTxnRequest, + EndTxnRequest, InitProducerIdRequest, TxnOffsetCommitRequest, +) from kafka.structs import TopicPartition @@ -756,8 +755,8 @@ def __init__(self, transaction_manager, coord_type, coord_key): else: raise ValueError("Unrecognized coordinator type: %s" % (coord_type,)) self.request = FindCoordinatorRequest[version]( - coordinator_key=coord_key, - coordinator_type=coord_type_int8, + key=coord_key, + key_type=coord_type_int8, ) @property @@ -942,7 +941,7 @@ def handle_response(self, response): log.debug("Successfully added offsets for %s from consumer group %s to transaction.", tp, self.consumer_group_id) del self.transaction_manager._pending_txn_offset_commits[tp] - elif error in (errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError, Errors.RequestTimedOutError): + elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError, Errors.RequestTimedOutError): retriable_failure = True lookup_coordinator = True elif error is Errors.UnknownTopicOrPartitionError: diff --git a/test/integration/test_producer_integration.py b/test/integration/test_producer_integration.py index c7e6ad4c9..94b3668ff 100644 --- a/test/integration/test_producer_integration.py +++ b/test/integration/test_producer_integration.py @@ -194,6 +194,7 @@ def test_transactional_producer_offsets(kafka_broker): with producer_factory(bootstrap_servers=connect_str, transactional_id='testing') as producer: producer.init_transactions() producer.begin_transaction() + producer.send('transactional_test_topic', partition=0, value=b'msg1').get() producer.send_offsets_to_transaction(offsets, 'txn-test-group') producer.commit_transaction() diff --git a/test/test_sender.py b/test/test_sender.py index 08012adb3..a7c0020ae 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -13,7 +13,7 @@ import kafka.errors as Errors from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.producer.kafka import KafkaProducer -from kafka.protocol.produce import ProduceRequest +from kafka.protocol.new.producer import ProduceRequest from kafka.producer.future import FutureRecordMetadata from kafka.producer.producer_batch import ProducerBatch from kafka.producer.record_accumulator import RecordAccumulator @@ -64,7 +64,8 @@ def test_produce_request(sender, api_version, produce_version): magic = KafkaProducer.max_usable_produce_magic(api_version) batch = producer_batch(magic=magic) produce_request = sender._produce_request(0, 0, 0, [batch]) - assert isinstance(produce_request, ProduceRequest[produce_version]) + assert isinstance(produce_request, ProduceRequest) + assert produce_request.version == produce_version @pytest.mark.parametrize(("api_version", "produce_version"), [ @@ -81,7 +82,8 @@ def test_create_produce_requests(sender, api_version, produce_version): produce_requests_by_node = sender._create_produce_requests(batches_by_node) assert len(produce_requests_by_node) == 3 for node in range(3): - assert isinstance(produce_requests_by_node[node], ProduceRequest[produce_version]) + assert isinstance(produce_requests_by_node[node], ProduceRequest) + assert produce_requests_by_node[node].version == produce_version def test_complete_batch_success(sender):