diff --git a/group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/assignor/StickyTaskAssignor.java b/group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/assignor/StickyTaskAssignor.java index f455bb577eb68..284d9b0f16d9f 100644 --- a/group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/assignor/StickyTaskAssignor.java +++ b/group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/assignor/StickyTaskAssignor.java @@ -20,12 +20,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Optional; +import java.util.PriorityQueue; import java.util.Set; import java.util.stream.Collectors; @@ -115,7 +117,7 @@ private void initialize(final GroupSpec groupSpec, final TopologyDescriber topol Set partitionNoSet = entry.getValue(); for (int partitionNo : partitionNoSet) { TaskId taskId = new TaskId(entry.getKey(), partitionNo); - localState.standbyTaskToPrevMember.putIfAbsent(taskId, new HashSet<>()); + localState.standbyTaskToPrevMember.putIfAbsent(taskId, new ArrayList<>()); localState.standbyTaskToPrevMember.get(taskId).add(member); } } @@ -171,8 +173,9 @@ private void assignActive(final Set activeTasks) { final TaskId task = it.next(); final Member prevMember = localState.activeTaskToPrevMember.get(task); if (prevMember != null && hasUnfulfilledQuota(prevMember)) { - localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true); - updateHelpers(prevMember, true); + ProcessState processState = localState.processIdToState.get(prevMember.processId); + processState.addTask(prevMember.memberId, task, true); + maybeUpdateTasksPerMember(processState.activeTaskCount()); it.remove(); } } @@ -180,29 +183,33 @@ private void assignActive(final Set activeTasks) { // 2. re-assigning tasks to clients that previously have seen the same task (as standby task) for (Iterator it = activeTasks.iterator(); it.hasNext();) { final TaskId task = it.next(); - final Set prevMembers = localState.standbyTaskToPrevMember.get(task); - final Member prevMember = findMemberWithLeastLoad(prevMembers, task, true); + final ArrayList prevMembers = localState.standbyTaskToPrevMember.get(task); + final Member prevMember = findPrevMemberWithLeastLoad(prevMembers, null); if (prevMember != null && hasUnfulfilledQuota(prevMember)) { - localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true); - updateHelpers(prevMember, true); + ProcessState processState = localState.processIdToState.get(prevMember.processId); + processState.addTask(prevMember.memberId, task, true); + maybeUpdateTasksPerMember(processState.activeTaskCount()); it.remove(); } } // 3. assign any remaining unassigned tasks + PriorityQueue processByLoad = new PriorityQueue<>(Comparator.comparingDouble(ProcessState::load)); + processByLoad.addAll(localState.processIdToState.values()); for (Iterator it = activeTasks.iterator(); it.hasNext();) { final TaskId task = it.next(); - final Set allMembers = localState.processIdToState.entrySet().stream().flatMap(entry -> entry.getValue().memberToTaskCounts().keySet().stream() - .map(memberId -> new Member(entry.getKey(), memberId))).collect(Collectors.toSet()); - final Member member = findMemberWithLeastLoad(allMembers, task, false); + ProcessState processWithLeastLoad = processByLoad.poll(); + if (processWithLeastLoad == null) { + throw new TaskAssignorException("No process available to assign active task {}." + task); + } + String member = memberWithLeastLoad(processWithLeastLoad); if (member == null) { - log.error("Unable to assign active task {} to any member.", task); throw new TaskAssignorException("No member available to assign active task {}." + task); } - localState.processIdToState.get(member.processId).addTask(member.memberId, task, true); + processWithLeastLoad.addTask(member, task, true); it.remove(); - updateHelpers(member, true); - + maybeUpdateTasksPerMember(processWithLeastLoad.activeTaskCount()); + processByLoad.add(processWithLeastLoad); // Add it back to the queue after updating its state } } @@ -214,29 +221,75 @@ private void maybeUpdateTasksPerMember(final int activeTasksNo) { } } - private Member findMemberWithLeastLoad(final Set members, TaskId taskId, final boolean returnSameMember) { + private boolean assignStandbyToMemberWithLeastLoad(PriorityQueue queue, TaskId taskId) { + ProcessState processWithLeastLoad = queue.poll(); + if (processWithLeastLoad == null) { + return false; + } + boolean found = false; + if (!processWithLeastLoad.hasTask(taskId)) { + String memberId = memberWithLeastLoad(processWithLeastLoad); + if (memberId != null) { + processWithLeastLoad.addTask(memberId, taskId, false); + found = true; + } + } else if (!queue.isEmpty()) { + found = assignStandbyToMemberWithLeastLoad(queue, taskId); + } + queue.add(processWithLeastLoad); // Add it back to the queue after updating its state + return found; + } + + /** + * Finds the previous member with the least load for a given task. + * + * @param members The list of previous members owning the task. + * @param taskId The taskId, to check if the previous member already has the task. Can be null, if we assign it + * for the first time (e.g., during active task assignment). + * + * @return Previous member with the least load that does not have the task, or null if no such member exists. + */ + private Member findPrevMemberWithLeastLoad(final ArrayList members, final TaskId taskId) { if (members == null || members.isEmpty()) { return null; } - Optional processWithLeastLoad = members.stream() - .map(member -> localState.processIdToState.get(member.processId)) - .min(Comparator.comparingDouble(ProcessState::load)); - - // if the same exact former member is needed - if (returnSameMember) { - return localState.standbyTaskToPrevMember.get(taskId).stream() - .filter(standby -> standby.processId.equals(processWithLeastLoad.get().processId())) - .findFirst() - .orElseGet(() -> memberWithLeastLoad(processWithLeastLoad.get())); + + Member candidate = members.get(0); + ProcessState candidateProcessState = localState.processIdToState.get(candidate.processId); + double candidateProcessLoad = candidateProcessState.load(); + double candidateMemberLoad = candidateProcessState.memberToTaskCounts().get(candidate.memberId); + for (int i = 1; i < members.size(); i++) { + Member member = members.get(i); + ProcessState processState = localState.processIdToState.get(member.processId); + double newProcessLoad = processState.load(); + if (newProcessLoad < candidateProcessLoad && (taskId == null || !processState.hasTask(taskId))) { + double newMemberLoad = processState.memberToTaskCounts().get(member.memberId); + if (newMemberLoad < candidateMemberLoad) { + candidateProcessLoad = newProcessLoad; + candidateMemberLoad = newMemberLoad; + candidate = member; + } + } } - return memberWithLeastLoad(processWithLeastLoad.get()); + + if (taskId == null || !candidateProcessState.hasTask(taskId)) { + return candidate; + } + return null; } - private Member memberWithLeastLoad(final ProcessState processWithLeastLoad) { + private String memberWithLeastLoad(final ProcessState processWithLeastLoad) { + Map members = processWithLeastLoad.memberToTaskCounts(); + if (members.isEmpty()) { + return null; + } + if (members.size() == 1) { + return members.keySet().iterator().next(); + } Optional memberWithLeastLoad = processWithLeastLoad.memberToTaskCounts().entrySet().stream() .min(Map.Entry.comparingByValue()) .map(Map.Entry::getKey); - return memberWithLeastLoad.map(memberId -> new Member(processWithLeastLoad.processId(), memberId)).orElse(null); + return memberWithLeastLoad.orElse(null); } private boolean hasUnfulfilledQuota(final Member member) { @@ -244,55 +297,49 @@ private boolean hasUnfulfilledQuota(final Member member) { } private void assignStandby(final Set standbyTasks, final int numStandbyReplicas) { + ArrayList toLeastLoaded = new ArrayList<>(standbyTasks.size() * numStandbyReplicas); for (TaskId task : standbyTasks) { for (int i = 0; i < numStandbyReplicas; i++) { - final Set availableProcesses = localState.processIdToState.values().stream() - .filter(process -> !process.hasTask(task)) - .map(ProcessState::processId) - .collect(Collectors.toSet()); - - if (availableProcesses.isEmpty()) { - log.warn("{} There is not enough available capacity. " + - "You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.", - errorMessage(numStandbyReplicas, i, task)); - break; - } - Member standby = null; - // prev active task Member prevMember = localState.activeTaskToPrevMember.get(task); - if (prevMember != null && availableProcesses.contains(prevMember.processId) && isLoadBalanced(prevMember.processId)) { - standby = prevMember; + if (prevMember != null) { + ProcessState prevMemberProcessState = localState.processIdToState.get(prevMember.processId); + if (!prevMemberProcessState.hasTask(task) && isLoadBalanced(prevMemberProcessState)) { + prevMemberProcessState.addTask(prevMember.memberId, task, false); + continue; + } } // prev standby tasks - if (standby == null) { - final Set prevMembers = localState.standbyTaskToPrevMember.get(task); - if (prevMembers != null && !prevMembers.isEmpty()) { - prevMembers.removeIf(member -> !availableProcesses.contains(member.processId)); - prevMember = findMemberWithLeastLoad(prevMembers, task, true); - if (prevMember != null && isLoadBalanced(prevMember.processId)) { - standby = prevMember; + final ArrayList prevMembers = localState.standbyTaskToPrevMember.get(task); + if (prevMembers != null && !prevMembers.isEmpty()) { + prevMember = findPrevMemberWithLeastLoad(prevMembers, task); + if (prevMember != null) { + ProcessState prevMemberProcessState = localState.processIdToState.get(prevMember.processId); + if (isLoadBalanced(prevMemberProcessState)) { + prevMemberProcessState.addTask(prevMember.memberId, task, false); + continue; } } } - // others - if (standby == null) { - final Set availableMembers = availableProcesses.stream() - .flatMap(pId -> localState.processIdToState.get(pId).memberToTaskCounts().keySet().stream() - .map(mId -> new Member(pId, mId))).collect(Collectors.toSet()); - standby = findMemberWithLeastLoad(availableMembers, task, false); - if (standby == null) { - log.warn("{} Error in standby task assignment!", errorMessage(numStandbyReplicas, i, task)); - break; - } - } - localState.processIdToState.get(standby.processId).addTask(standby.memberId, task, false); - updateHelpers(standby, false); + toLeastLoaded.add(new StandbyToAssign(task, numStandbyReplicas - i)); + break; } + } + PriorityQueue processByLoad = new PriorityQueue<>(Comparator.comparingDouble(ProcessState::load)); + processByLoad.addAll(localState.processIdToState.values()); + for (StandbyToAssign toAssign : toLeastLoaded) { + for (int i = 0; i < toAssign.remainingReplicas; i++) { + if (!assignStandbyToMemberWithLeastLoad(processByLoad, toAssign.taskId)) { + log.warn("{} There is not enough available capacity. " + + "You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.", + errorMessage(numStandbyReplicas, i, toAssign.taskId)); + break; + } + } } } @@ -301,21 +348,13 @@ private String errorMessage(final int numStandbyReplicas, final int i, final Tas " of " + numStandbyReplicas + " standby tasks for task [" + task + "]."; } - private boolean isLoadBalanced(final String processId) { - final ProcessState process = localState.processIdToState.get(processId); + private boolean isLoadBalanced(final ProcessState process) { final double load = process.load(); boolean isLeastLoadedProcess = localState.processIdToState.values().stream() .allMatch(p -> p.load() >= load); return process.hasCapacity() || isLeastLoadedProcess; } - private void updateHelpers(final Member member, final boolean isActive) { - if (isActive) { - // update task per process - maybeUpdateTasksPerMember(localState.processIdToState.get(member.processId).activeTaskCount()); - } - } - private static int computeTasksPerMember(final int numberOfTasks, final int numberOfMembers) { if (numberOfMembers == 0) { return 0; @@ -327,6 +366,16 @@ private static int computeTasksPerMember(final int numberOfTasks, final int numb return tasksPerMember; } + static class StandbyToAssign { + private final TaskId taskId; + private final int remainingReplicas; + + public StandbyToAssign(final TaskId taskId, final int remainingReplicas) { + this.taskId = taskId; + this.remainingReplicas = remainingReplicas; + } + } + static class Member { private final String processId; private final String memberId; @@ -340,11 +389,11 @@ public Member(final String processId, final String memberId) { private static class LocalState { // helper data structures: Map activeTaskToPrevMember; - Map> standbyTaskToPrevMember; + Map> standbyTaskToPrevMember; Map processIdToState; int allTasks; int totalCapacity; int tasksPerMember; } -} \ No newline at end of file +} diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java index c40b3433a91a7..d9df2077b2cdf 100644 --- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java @@ -140,15 +140,19 @@ public void shouldWorkWithRebalance( final Properties props = new Properties(); + final String appId = safeUniqueTestName(testInfo); props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); - props.put(StreamsConfig.APPLICATION_ID_CONFIG, safeUniqueTestName(testInfo)); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled); props.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled); - // decrease the session timeout so that we can trigger the rebalance soon after old client left closed - props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); - props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500); if (streamsProtocolEnabled) { props.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.name().toLowerCase(Locale.getDefault())); + // decrease the session timeout so that we can trigger the rebalance soon after old client left closed + CLUSTER.setGroupSessionTimeout(appId, 10000); + CLUSTER.setGroupHeartbeatTimeout(appId, 1000); + } else { + // decrease the session timeout so that we can trigger the rebalance soon after old client left closed + props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); } // cycle out Streams instances as long as the test is running. diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java index ec48b8b36349c..8c8ef3dae9c7d 100644 --- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java @@ -99,7 +99,7 @@ private Properties streamsConfiguration(final boolean streamsProtocolEnabled) { streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class); if (streamsProtocolEnabled) { streamsConfiguration.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.name().toLowerCase(Locale.getDefault())); - CLUSTER.setStandbyReplicas("app-" + safeTestName, 1); + CLUSTER.setGroupStandbyReplicas("app-" + safeTestName, 1); } else { streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); } diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java index f425f8365ee57..1de7a45bfce7a 100644 --- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java @@ -401,6 +401,8 @@ public KafkaConsumer createConsumerAndSubscribeTo(final Map