Skip to content

Commit 30f767c

Browse files
committed
use api_version() to select all api versions in consumer/coordinator/producer
1 parent 8a2c91e commit 30f767c

File tree

7 files changed

+57
-85
lines changed

7 files changed

+57
-85
lines changed

kafka/consumer/fetcher.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -569,10 +569,8 @@ def _send_offset_request(self, node_id, timestamps):
569569
data = (tp.partition, timestamp, 1)
570570
by_topic[tp.topic].append(data)
571571

572-
if self.config['api_version'] >= (0, 10, 1):
573-
request = OffsetRequest[1](-1, list(six.iteritems(by_topic)))
574-
else:
575-
request = OffsetRequest[0](-1, list(six.iteritems(by_topic)))
572+
version = self._client.api_version(OffsetRequest, max_version=1)
573+
request = OffsetRequest[version](-1, list(six.iteritems(by_topic)))
576574

577575
# Client returns a future that only fails on network issues
578576
# so create a separate future and attach a callback to update it
@@ -702,16 +700,7 @@ def _create_fetch_requests(self):
702700
log.log(0, "Skipping fetch for partition %s because there is an inflight request to node %s",
703701
partition, node_id)
704702

705-
if self.config['api_version'] >= (0, 11):
706-
version = 4
707-
elif self.config['api_version'] >= (0, 10, 1):
708-
version = 3
709-
elif self.config['api_version'] >= (0, 10, 0):
710-
version = 2
711-
elif self.config['api_version'] == (0, 9):
712-
version = 1
713-
else:
714-
version = 0
703+
version = self._client.api_version(FetchRequest, max_version=4)
715704
requests = {}
716705
for node_id, partition_data in six.iteritems(fetchable):
717706
if version < 3:

kafka/coordinator/base.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -452,25 +452,18 @@ def _send_join_group_request(self):
452452
(protocol, metadata if isinstance(metadata, bytes) else metadata.encode())
453453
for protocol, metadata in self.group_protocols()
454454
]
455-
if self.config['api_version'] < (0, 9):
455+
version = self._client.api_version(JoinGroupRequest, max_version=2)
456+
if not version:
456457
raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers')
457-
elif (0, 9) <= self.config['api_version'] < (0, 10, 1):
458-
request = JoinGroupRequest[0](
458+
elif version == 0:
459+
request = JoinGroupRequest[version](
459460
self.group_id,
460461
self.config['session_timeout_ms'],
461462
self._generation.member_id,
462463
self.protocol_type(),
463464
member_metadata)
464-
elif (0, 10, 1) <= self.config['api_version'] < (0, 11):
465-
request = JoinGroupRequest[1](
466-
self.group_id,
467-
self.config['session_timeout_ms'],
468-
self.config['max_poll_interval_ms'],
469-
self._generation.member_id,
470-
self.protocol_type(),
471-
member_metadata)
472465
else:
473-
request = JoinGroupRequest[2](
466+
request = JoinGroupRequest[version](
474467
self.group_id,
475468
self.config['session_timeout_ms'],
476469
self.config['max_poll_interval_ms'],
@@ -562,7 +555,7 @@ def _handle_join_group_response(self, future, send_time, response):
562555

563556
def _on_join_follower(self):
564557
# send follower's sync group with an empty assignment
565-
version = 0 if self.config['api_version'] < (0, 11) else 1
558+
version = self._client.api_version(SyncGroupRequest, max_version=1)
566559
request = SyncGroupRequest[version](
567560
self.group_id,
568561
self._generation.generation_id,
@@ -590,7 +583,7 @@ def _on_join_leader(self, response):
590583
except Exception as e:
591584
return Future().failure(e)
592585

593-
version = 0 if self.config['api_version'] < (0, 11) else 1
586+
version = self._client.api_version(SyncGroupRequest, max_version=1)
594587
request = SyncGroupRequest[version](
595588
self.group_id,
596589
self._generation.generation_id,
@@ -744,7 +737,7 @@ def _start_heartbeat_thread(self):
744737
self._heartbeat_thread.start()
745738

746739
def _close_heartbeat_thread(self):
747-
if self._heartbeat_thread is not None:
740+
if hasattr(self, '_heartbeat_thread') and self._heartbeat_thread is not None:
748741
log.info('Stopping heartbeat thread')
749742
try:
750743
self._heartbeat_thread.close()
@@ -771,7 +764,7 @@ def maybe_leave_group(self):
771764
# this is a minimal effort attempt to leave the group. we do not
772765
# attempt any resending if the request fails or times out.
773766
log.info('Leaving consumer group (%s).', self.group_id)
774-
version = 0 if self.config['api_version'] < (0, 11) else 1
767+
version = self._client.api_version(LeaveGroupRequest, max_version=1)
775768
request = LeaveGroupRequest[version](self.group_id, self._generation.member_id)
776769
future = self._client.send(self.coordinator_id, request)
777770
future.add_callback(self._handle_leave_group_response)
@@ -799,7 +792,7 @@ def _send_heartbeat_request(self):
799792
e = Errors.NodeNotReadyError(self.coordinator_id)
800793
return Future().failure(e)
801794

802-
version = 0 if self.config['api_version'] < (0, 11) else 1
795+
version = self._client.api_version(HeartbeatRequest, max_version=1)
803796
request = HeartbeatRequest[version](self.group_id,
804797
self._generation.generation_id,
805798
self._generation.member_id)

kafka/coordinator/consumer.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,13 @@ def _send_offset_commit_request(self, offsets):
582582
if self.config['api_version'] >= (0, 9) and generation is None:
583583
return Future().failure(Errors.CommitFailedError())
584584

585-
if self.config['api_version'] >= (0, 9):
586-
request = OffsetCommitRequest[2](
585+
version = self._client.api_version(OffsetCommitRequest, max_version=2)
586+
if version == 2:
587+
request = OffsetCommitRequest[version](
587588
self.group_id,
588589
generation.generation_id,
589590
generation.member_id,
590-
OffsetCommitRequest[2].DEFAULT_RETENTION_TIME,
591+
OffsetCommitRequest[version].DEFAULT_RETENTION_TIME,
591592
[(
592593
topic, [(
593594
partition,
@@ -596,8 +597,8 @@ def _send_offset_commit_request(self, offsets):
596597
) for partition, offset in six.iteritems(partitions)]
597598
) for topic, partitions in six.iteritems(offset_data)]
598599
)
599-
elif self.config['api_version'] >= (0, 8, 2):
600-
request = OffsetCommitRequest[1](
600+
elif version == 1:
601+
request = OffsetCommitRequest[version](
601602
self.group_id, -1, '',
602603
[(
603604
topic, [(
@@ -608,8 +609,8 @@ def _send_offset_commit_request(self, offsets):
608609
) for partition, offset in six.iteritems(partitions)]
609610
) for topic, partitions in six.iteritems(offset_data)]
610611
)
611-
elif self.config['api_version'] >= (0, 8, 1):
612-
request = OffsetCommitRequest[0](
612+
elif version == 0:
613+
request = OffsetCommitRequest[version](
613614
self.group_id,
614615
[(
615616
topic, [(
@@ -731,16 +732,11 @@ def _send_offset_fetch_request(self, partitions):
731732
for tp in partitions:
732733
topic_partitions[tp.topic].add(tp.partition)
733734

734-
if self.config['api_version'] >= (0, 8, 2):
735-
request = OffsetFetchRequest[1](
736-
self.group_id,
737-
list(topic_partitions.items())
738-
)
739-
else:
740-
request = OffsetFetchRequest[0](
741-
self.group_id,
742-
list(topic_partitions.items())
743-
)
735+
version = self._client.api_version(OffsetFetchRequest, max_version=1)
736+
request = OffsetFetchRequest[version](
737+
self.group_id,
738+
list(topic_partitions.items())
739+
)
744740

745741
# send the request with a callback
746742
future = Future()

kafka/producer/sender.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -301,31 +301,14 @@ def _produce_request(self, node_id, acks, timeout, batches):
301301
buf = batch.records.buffer()
302302
produce_records_by_partition[topic][partition] = buf
303303

304-
kwargs = {}
305-
if self.config['api_version'] >= (2, 1):
306-
version = 7
307-
elif self.config['api_version'] >= (2, 0):
308-
version = 6
309-
elif self.config['api_version'] >= (1, 1):
310-
version = 5
311-
elif self.config['api_version'] >= (1, 0):
312-
version = 4
313-
elif self.config['api_version'] >= (0, 11):
314-
version = 3
315-
kwargs = dict(transactional_id=None)
316-
elif self.config['api_version'] >= (0, 10, 0):
317-
version = 2
318-
elif self.config['api_version'] == (0, 9):
319-
version = 1
320-
else:
321-
version = 0
304+
version = self._client.api_version(ProduceRequest, max_version=7)
305+
# TODO: support transactional_id
322306
return ProduceRequest[version](
323307
required_acks=acks,
324308
timeout=timeout,
325309
topics=[(topic, list(partition_info.items()))
326310
for topic, partition_info
327311
in six.iteritems(produce_records_by_partition)],
328-
**kwargs
329312
)
330313

331314
def wakeup(self):

test/test_coordinator.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import kafka.errors as Errors
1818
from kafka.future import Future
1919
from kafka.metrics import Metrics
20+
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
2021
from kafka.protocol.commit import (
2122
OffsetCommitRequest, OffsetCommitResponse,
2223
OffsetFetchRequest, OffsetFetchResponse)
@@ -41,8 +42,9 @@ def test_init(client, coordinator):
4142

4243

4344
@pytest.mark.parametrize("api_version", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)])
44-
def test_autocommit_enable_api_version(client, api_version):
45-
coordinator = ConsumerCoordinator(client, SubscriptionState(),
45+
def test_autocommit_enable_api_version(conn, api_version):
46+
coordinator = ConsumerCoordinator(KafkaClient(api_version=api_version),
47+
SubscriptionState(),
4648
Metrics(),
4749
enable_auto_commit=True,
4850
session_timeout_ms=30000, # session_timeout_ms and max_poll_interval_ms
@@ -86,8 +88,13 @@ def test_group_protocols(coordinator):
8688

8789

8890
@pytest.mark.parametrize('api_version', [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)])
89-
def test_pattern_subscription(coordinator, api_version):
90-
coordinator.config['api_version'] = api_version
91+
def test_pattern_subscription(conn, api_version):
92+
coordinator = ConsumerCoordinator(KafkaClient(api_version=api_version),
93+
SubscriptionState(),
94+
Metrics(),
95+
api_version=api_version,
96+
session_timeout_ms=10000,
97+
max_poll_interval_ms=10000)
9198
coordinator._subscription.subscribe(pattern='foo')
9299
assert coordinator._subscription.subscription == set([])
93100
assert coordinator._metadata_snapshot == coordinator._build_metadata_snapshot(coordinator._subscription, {})
@@ -436,7 +443,7 @@ def test_send_offset_commit_request_fail(mocker, patched_coord, offsets):
436443
def test_send_offset_commit_request_versions(patched_coord, offsets,
437444
api_version, req_type):
438445
expect_node = 0
439-
patched_coord.config['api_version'] = api_version
446+
patched_coord._client._api_versions = BROKER_API_VERSIONS[api_version]
440447

441448
patched_coord._send_offset_commit_request(offsets)
442449
(node, request), _ = patched_coord._client.send.call_args
@@ -532,7 +539,7 @@ def test_send_offset_fetch_request_versions(patched_coord, partitions,
532539
api_version, req_type):
533540
# assuming fixture sets coordinator=0, least_loaded_node=1
534541
expect_node = 0
535-
patched_coord.config['api_version'] = api_version
542+
patched_coord._client._api_versions = BROKER_API_VERSIONS[api_version]
536543

537544
patched_coord._send_offset_fetch_request(partitions)
538545
(node, request), _ = patched_coord._client.send.call_args

test/test_fetcher.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import kafka.errors as Errors
1717
from kafka.future import Future
1818
from kafka.metrics import Metrics
19+
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
1920
from kafka.protocol.fetch import FetchRequest, FetchResponse
2021
from kafka.protocol.offset import OffsetResponse
2122
from kafka.errors import (
@@ -27,8 +28,8 @@
2728

2829

2930
@pytest.fixture
30-
def client(mocker):
31-
return mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9)))
31+
def client():
32+
return KafkaClient(bootstrap_servers=(), api_version=(0, 9))
3233

3334

3435
@pytest.fixture
@@ -81,6 +82,8 @@ def test_send_fetches(fetcher, topic, mocker):
8182
mocker.patch.object(fetcher, '_create_fetch_requests',
8283
return_value=dict(enumerate(fetch_requests)))
8384

85+
mocker.patch.object(fetcher._client, 'ready', return_value=True)
86+
mocker.patch.object(fetcher._client, 'send')
8487
ret = fetcher.send_fetches()
8588
for node, request in enumerate(fetch_requests):
8689
fetcher._client.send.assert_any_call(node, request, wakeup=False)
@@ -91,14 +94,14 @@ def test_send_fetches(fetcher, topic, mocker):
9194
((0, 10, 1), 3),
9295
((0, 10, 0), 2),
9396
((0, 9), 1),
94-
((0, 8), 0)
97+
((0, 8, 2), 0)
9598
])
9699
def test_create_fetch_requests(fetcher, mocker, api_version, fetch_version):
97-
fetcher._client.in_flight_request_count.return_value = 0
98-
fetcher.config['api_version'] = api_version
100+
fetcher._client._api_versions = BROKER_API_VERSIONS[api_version]
101+
mocker.patch.object(fetcher._client.cluster, "leader_for_partition", return_value=0)
99102
by_node = fetcher._create_fetch_requests()
100103
requests = by_node.values()
101-
assert all([isinstance(r, FetchRequest[fetch_version]) for r in requests])
104+
assert set([r.API_VERSION for r in requests]) == set([fetch_version])
102105

103106

104107
def test_update_fetch_positions(fetcher, topic, mocker):
@@ -485,6 +488,7 @@ def test__parse_fetched_data__not_leader(fetcher, topic, mocker):
485488
tp, 0, 0, [NotLeaderForPartitionError.errno, -1, None],
486489
mocker.MagicMock()
487490
)
491+
mocker.patch.object(fetcher._client.cluster, 'request_update')
488492
partition_record = fetcher._parse_fetched_data(completed_fetch)
489493
assert partition_record is None
490494
fetcher._client.cluster.request_update.assert_called_with()
@@ -497,6 +501,7 @@ def test__parse_fetched_data__unknown_tp(fetcher, topic, mocker):
497501
tp, 0, 0, [UnknownTopicOrPartitionError.errno, -1, None],
498502
mocker.MagicMock()
499503
)
504+
mocker.patch.object(fetcher._client.cluster, 'request_update')
500505
partition_record = fetcher._parse_fetched_data(completed_fetch)
501506
assert partition_record is None
502507
fetcher._client.cluster.request_update.assert_called_with()

test/test_sender.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from kafka.client_async import KafkaClient
88
from kafka.cluster import ClusterMetadata
99
from kafka.metrics import Metrics
10+
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
1011
from kafka.protocol.produce import ProduceRequest
1112
from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch
1213
from kafka.producer.sender import Sender
@@ -15,10 +16,8 @@
1516

1617

1718
@pytest.fixture
18-
def client(mocker):
19-
_cli = mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9)))
20-
_cli.cluster = mocker.Mock(spec=ClusterMetadata())
21-
return _cli
19+
def client():
20+
return KafkaClient(bootstrap_servers=(), api_version=(0, 9))
2221

2322

2423
@pytest.fixture
@@ -32,7 +31,7 @@ def metrics():
3231

3332

3433
@pytest.fixture
35-
def sender(client, accumulator, metrics):
34+
def sender(client, accumulator, metrics, mocker):
3635
return Sender(client, client.cluster, accumulator, metrics)
3736

3837

@@ -42,7 +41,7 @@ def sender(client, accumulator, metrics):
4241
((0, 8, 0), 0)
4342
])
4443
def test_produce_request(sender, mocker, api_version, produce_version):
45-
sender.config['api_version'] = api_version
44+
sender._client._api_versions = BROKER_API_VERSIONS[api_version]
4645
tp = TopicPartition('foo', 0)
4746
buffer = io.BytesIO()
4847
records = MemoryRecordsBuilder(

0 commit comments

Comments
 (0)