diff --git a/kafka/coordinator/assignors/abstract.py b/kafka/coordinator/assignors/abstract.py index 3cdc2cace..04d84d285 100644 --- a/kafka/coordinator/assignors/abstract.py +++ b/kafka/coordinator/assignors/abstract.py @@ -43,7 +43,7 @@ def metadata(self, topics): pass @abc.abstractmethod - def on_assignment(self, assignment): + def on_assignment(self, assignment, generation): """Callback that runs on each assignment. This method can be used to update internal state, if any, of the @@ -51,5 +51,6 @@ def on_assignment(self, assignment): Arguments: assignment (MemberAssignment): the member's assignment + generation (int): generation id of assignment """ pass diff --git a/kafka/coordinator/assignors/range.py b/kafka/coordinator/assignors/range.py index 307ae0b76..dc45f8a9a 100644 --- a/kafka/coordinator/assignors/range.py +++ b/kafka/coordinator/assignors/range.py @@ -74,5 +74,5 @@ def metadata(cls, topics): return ConsumerProtocolMemberMetadata_v0(cls.version, list(topics), b'') @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment, generation): pass diff --git a/kafka/coordinator/assignors/roundrobin.py b/kafka/coordinator/assignors/roundrobin.py index f73a10679..83d98ee58 100644 --- a/kafka/coordinator/assignors/roundrobin.py +++ b/kafka/coordinator/assignors/roundrobin.py @@ -93,5 +93,5 @@ def metadata(cls, topics): return ConsumerProtocolMemberMetadata_v0(cls.version, list(topics), b'') @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment, generation): pass diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/kafka/coordinator/assignors/sticky/sticky_assignor.py index bb6c10fe2..158a21da7 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -666,21 +666,12 @@ def _metadata(cls, topics, member_assignment_partitions, generation=-1): return ConsumerProtocolMemberMetadata_v0(cls.version, list(topics), user_data) @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment, generation): """Callback that runs on each assignment. Updates assignor's state. Arguments: assignment: MemberAssignment """ - log.debug("On assignment: assignment={}".format(assignment)) + log.debug(f"On assignment: assignment={assignment}, generation={generation}") cls.member_assignment = assignment.partitions() - - @classmethod - def on_generation_assignment(cls, generation): - """Callback that runs on each assignment. Updates assignor's generation id. - - Arguments: - generation: generation id - """ - log.debug("On generation assignment: generation={}".format(generation)) cls.generation = generation diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index 3f834ccf4..a9335a1a5 100644 --- a/kafka/coordinator/consumer.py +++ b/kafka/coordinator/consumer.py @@ -245,9 +245,7 @@ def _on_join_complete(self, generation, member_id, protocol, # give the assignor a chance to update internal state # based on the received assignment - assignor.on_assignment(assignment) - if assignor.name == 'sticky': - assignor.on_generation_assignment(generation) + assignor.on_assignment(assignment, generation) # reschedule the auto commit starting from now self.next_auto_commit_deadline = time.monotonic() + self.auto_commit_interval diff --git a/test/test_coordinator.py b/test/test_coordinator.py index 542a7b313..3af922696 100644 --- a/test/test_coordinator.py +++ b/test/test_coordinator.py @@ -136,9 +136,10 @@ def test_join_complete(mocker, coordinator): mocker.spy(assignor, 'on_assignment') assert assignor.on_assignment.call_count == 0 assignment = ConsumerProtocolMemberAssignment_v0(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete(0, 'member-foo', 'roundrobin', assignment.encode()) + generation = 12 + coordinator._on_join_complete(generation, 'member-foo', 'roundrobin', assignment.encode()) assert assignor.on_assignment.call_count == 1 - assignor.on_assignment.assert_called_with(assignment) + assignor.on_assignment.assert_called_with(assignment, generation) def test_join_complete_with_sticky_assignor(mocker, coordinator): @@ -146,15 +147,12 @@ def test_join_complete_with_sticky_assignor(mocker, coordinator): assignor = StickyPartitionAssignor() coordinator.config['assignors'] = (assignor,) mocker.spy(assignor, 'on_assignment') - mocker.spy(assignor, 'on_generation_assignment') assert assignor.on_assignment.call_count == 0 - assert assignor.on_generation_assignment.call_count == 0 + generation = 3 assignment = ConsumerProtocolMemberAssignment_v0(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete(0, 'member-foo', 'sticky', assignment.encode()) + coordinator._on_join_complete(generation, 'member-foo', 'sticky', assignment.encode()) assert assignor.on_assignment.call_count == 1 - assert assignor.on_generation_assignment.call_count == 1 - assignor.on_assignment.assert_called_with(assignment) - assignor.on_generation_assignment.assert_called_with(0) + assignor.on_assignment.assert_called_with(assignment, generation) def test_subscription_listener(mocker, coordinator):