Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions kafka/coordinator/assignors/sticky/sticky_assignor.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,13 +565,12 @@ class StickyPartitionAssignor(AbstractPartitionAssignor):
name = "sticky"
version = 0

member_assignment = None
generation = DEFAULT_GENERATION_ID
def __init__(self):
self.member_assignment = None
self.generation = self.DEFAULT_GENERATION_ID
self._latest_partition_movements = None

_latest_partition_movements = None

@classmethod
def assign(cls, cluster, members):
def assign(self, cluster, members):
"""Performs group assignment given cluster metadata and member subscriptions

Arguments:
Expand All @@ -582,18 +581,19 @@ def assign(cls, cluster, members):
dict: {member_id: ConsumerProtocolAssignment}
"""
members_metadata = {
member.member_id: cls.parse_member_metadata(member.metadata)
member.member_id: self.parse_member_metadata(member.metadata)
for member in members
}
executor = StickyAssignmentExecutor(cluster, members_metadata)
executor.perform_initial_assignment()
executor.balance()

cls._latest_partition_movements = executor.partition_movements
# store for tests
self._latest_partition_movements = executor.partition_movements

assignment = {
member.member_id: ConsumerProtocolAssignment(
cls.version, sorted(executor.get_final_assignment(member.member_id)), b'')
self.version, sorted(executor.get_final_assignment(member.member_id)), b'')
for member in members
}
return assignment
Expand Down Expand Up @@ -635,31 +635,29 @@ def parse_member_metadata(cls, metadata):
partitions=member_partitions, generation=decoded_user_data.generation, subscription=metadata.topics
)

@classmethod
def metadata(cls, topics):
return cls._metadata(topics, cls.member_assignment, cls.generation)
def metadata(self, topics):
return self._metadata(topics, self.member_assignment, self.generation)

@classmethod
def _metadata(cls, topics, member_assignment_partitions, generation=-1):
if member_assignment_partitions is None:
log.debug("No member assignment available")
user_data = b''
else:
log.debug("Member assignment is available, generating the metadata: generation {}".format(cls.generation))
log.debug("Member assignment is available, generating the metadata: generation {}".format(generation))
partitions_by_topic = defaultdict(list)
for topic_partition in member_assignment_partitions:
partitions_by_topic[topic_partition.topic].append(topic_partition.partition)
data = StickyAssignorUserData(list(partitions_by_topic.items()), generation)
user_data = data.encode()
return ConsumerProtocolSubscription(cls.version, list(topics), user_data)

@classmethod
def on_assignment(cls, assignment, generation):
def on_assignment(self, assignment, generation):
"""Callback that runs on each assignment. Updates assignor's state.

Arguments:
assignment: MemberAssignment
"""
log.debug(f"On assignment: assignment={assignment}, generation={generation}")
cls.member_assignment = assignment.partitions()
cls.generation = generation
self.member_assignment = assignment.partitions()
self.generation = generation
21 changes: 11 additions & 10 deletions kafka/coordinator/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __init__(self, client, subscription, **configs):
else:
self._consumer_sensors = None

self._assignors = {}
for klass in self.config['assignors']:
assignor = klass()
self._assignors[assignor.name] = assignor
self._cluster.request_update()
self._cluster.add_listener(WeakMethod(self._handle_metadata_update))

Expand Down Expand Up @@ -166,9 +170,9 @@ def group_protocols(self):
# best I've got for now.
self._joined_subscription = set(self._subscription.subscription)
metadata_list = []
for assignor in self.config['assignors']:
metadata = assignor.metadata(self._joined_subscription)
group_protocol = (assignor.name, metadata)
for assignor in self._assignors:
metadata = self._assignors[assignor].metadata(self._joined_subscription)
group_protocol = (assignor, metadata)
metadata_list.append(group_protocol)
return metadata_list

Expand Down Expand Up @@ -221,10 +225,7 @@ def _build_metadata_snapshot(self, subscription, cluster):
return metadata_snapshot

def _lookup_assignor(self, name):
for assignor in self.config['assignors']:
if assignor.name == name:
return assignor
return None
return self._assignors.get(name, None)

def _on_join_complete(self, generation, member_id, protocol,
member_assignment_bytes):
Expand Down Expand Up @@ -326,9 +327,9 @@ def time_to_next_poll(self):
return min(self.next_auto_commit_deadline - time.monotonic(),
self.time_to_next_heartbeat())

def _perform_assignment(self, leader_id, assignment_strategy, members):
assignor = self._lookup_assignor(assignment_strategy)
assert assignor, 'Invalid assignment protocol: %s' % (assignment_strategy,)
def _perform_assignment(self, leader_id, protocol_name, members):
assignor = self._lookup_assignor(protocol_name)
assert assignor, 'Invalid assignment protocol: %s' % (protocol_name,)
all_subscribed_topics = set()
for member in members:
member.metadata = ConsumerProtocolSubscription.decode(member.metadata)
Expand Down
Loading
Loading