diff --git a/.github/workflows/cpp_integration.yml b/.github/workflows/cpp_integration.yml
index 1c521f5d0a0..4f04c4badcf 100644
--- a/.github/workflows/cpp_integration.yml
+++ b/.github/workflows/cpp_integration.yml
@@ -85,24 +85,31 @@ jobs:
           check-latest: false
       - name: Compile & Install Celeborn Java
         run: build/mvn clean install -DskipTests
-      - name: Run Java-Cpp Hybrid Integration Test
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (NONE Decompression)
        run: |
          build/mvn -pl worker \
          test-compile exec:java \
          -Dexec.classpathScope="test" \
          -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithNONE" \
          -Dexec.args="-XX:MaxDirectMemorySize=2G"
-      - name: Run Java-Cpp Hybrid Integration Test (LZ4 Decompression)
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (LZ4 Decompression)
        run: |
          build/mvn -pl worker \
          test-compile exec:java \
          -Dexec.classpathScope="test" \
          -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithLZ4" \
          -Dexec.args="-XX:MaxDirectMemorySize=2G"
-      - name: Run Java-Cpp Hybrid Integration Test (ZSTD Decompression)
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (ZSTD Decompression)
        run: |
          build/mvn -pl worker \
          test-compile exec:java \
          -Dexec.classpathScope="test" \
          -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithZSTD" \
          -Dexec.args="-XX:MaxDirectMemorySize=2G"
+      - name: Run Cpp-Write Java-Read Hybrid Integration Test (NONE Compression)
+        run: |
+          build/mvn -pl worker \
+          test-compile exec:java \
+          -Dexec.classpathScope="test" \
+          -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
+          -Dexec.args="-XX:MaxDirectMemorySize=2G"
\ No newline at end of file
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
index 87a80006a4d..6f52b4aab5e 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
@@ -46,6 +46,7 @@
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.protocol.MessageType;
@@ -528,7 +529,7 @@ public Optional<PartitionLocation> regionStart(
   public Optional<PartitionLocation> revive(
       int shuffleId, int mapId, int attemptId, PartitionLocation location)
       throws CelebornIOException {
-    Set<Integer> mapIds = new HashSet<>();
+    List<Integer> mapIds = new ArrayList<>();
     mapIds.add(mapId);
     List<ReviveRequest> requests = new ArrayList<>();
     ReviveRequest req =
@@ -543,7 +544,7 @@ public Optional<PartitionLocation> revive(
     requests.add(req);
     PbChangeLocationResponse response =
         lifecycleManagerRef.askSync(
-            ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+            ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests, SerdeVersion.V1),
             conf.clientRpcRequestPartitionLocationAskTimeout(),
             ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
     // per partitionKey only serve single PartitionLocation in Client Cache.
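Note: the thread running through this patch is that SerdeVersion is now carried on every control-plane message, so the LifecycleManager can answer each caller in the encoding it understands; V1 is the existing protobuf encoding used by Java-side clients. A minimal Scala sketch of the new call shape (`shuffleId`, `mapId`, and `requests` are assumed to be in scope, as in the Flink client above):

    // Sketch: building the Revive control message with an explicit SerdeVersion.
    // mapIds is now a java.util.List rather than a Set, matching the new case class.
    import java.util
    import org.apache.celeborn.common.network.protocol.SerdeVersion
    import org.apache.celeborn.common.protocol.message.ControlMessages.Revive

    val mapIds = new util.ArrayList[Integer]()
    mapIds.add(mapId)
    val revive = Revive(shuffleId, mapIds, requests, SerdeVersion.V1)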
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index cfe40a2968f..f6a7d9750e3 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -550,11 +550,11 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
         numPartitions,
         () ->
             lifecycleManagerRef.askSync(
-                RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
+                new RegisterShuffle(shuffleId, numMappers, numPartitions, SerdeVersion.V1),
                 conf.clientRpcRegisterShuffleAskTimeout(),
                 rpcMaxRetries,
                 rpcRetryWait,
-                ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+                ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
   }
 
   @Override
@@ -593,7 +593,7 @@ public PartitionLocation registerMapPartitionTask(
                 partitionId,
                 isSegmentGranularityVisible),
             conf.clientRpcRegisterShuffleAskTimeout(),
-            ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+            ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
 
     return partitionLocationMap.get(partitionId);
   }
@@ -709,23 +709,18 @@ public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier
   }
 
   private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
-      int shuffleId,
-      int numMappers,
-      int numPartitions,
-      Callable<PbRegisterShuffleResponse> callable)
+      int shuffleId, int numMappers, int numPartitions, Callable<RegisterShuffleResponse> callable)
       throws CelebornIOException {
     int numRetries = registerShuffleMaxRetries;
     StatusCode lastFailedStatusCode = null;
     while (numRetries > 0) {
       try {
-        PbRegisterShuffleResponse response = callable.call();
-        StatusCode respStatus = StatusCode.fromValue(response.getStatus());
+        RegisterShuffleResponse response = callable.call();
+        StatusCode respStatus = response.status();
         if (StatusCode.SUCCESS.equals(respStatus)) {
           ConcurrentHashMap<Integer, PartitionLocation> result = JavaUtils.newConcurrentHashMap();
-          Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
-              PbSerDeUtils.fromPbPackedPartitionLocationsPair(
-                  response.getPackedPartitionLocationsPair());
-          for (PartitionLocation location : locations._1) {
+          PartitionLocation[] locations = response.partitionLocations();
+          for (PartitionLocation location : locations) {
             pushExcludedWorkers.remove(location.hostAndPushPort());
             if (location.hasPeer()) {
               pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
@@ -900,43 +895,43 @@ Map<Integer, Integer> reviveBatch(
       oldLocMap.put(req.partitionId, req.loc);
     }
     try {
-      PbChangeLocationResponse response =
+      ChangeLocationResponse response =
           lifecycleManagerRef.askSync(
-              Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+              Revive$.MODULE$.apply(
+                  shuffleId, new ArrayList<>(mapIds), new ArrayList<>(requests), SerdeVersion.V1),
              conf.clientRpcRequestPartitionLocationAskTimeout(),
-              ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
+              ClassTag$.MODULE$.apply(ChangeLocationResponse.class));
 
-      for (int i = 0; i < response.getEndedMapIdCount(); i++) {
-        int mapId = response.getEndedMapId(i);
+      for (int i = 0; i < response.endedMapIds().size(); i++) {
+        int mapId = response.endedMapIds().get(i);
         mapperEndMap.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()).add(mapId);
       }
 
-      for (int i = 0; i < response.getPartitionInfoCount(); i++) {
-        PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(i);
-        int partitionId = partitionInfo.getPartitionId();
-        int statusCode = partitionInfo.getStatus();
-        if (partitionInfo.getOldAvailable()) {
+      for (Map.Entry<Integer, Tuple3<StatusCode, Boolean, PartitionLocation>> entry :
+          response.newLocs().entrySet()) {
+        int partitionId = entry.getKey();
+        StatusCode statusCode = entry.getValue()._1();
+        if (entry.getValue()._2() != null) {
           PartitionLocation oldLoc = oldLocMap.get(partitionId);
           // Currently, revive only check if main location available, here won't remove peer loc.
           pushExcludedWorkers.remove(oldLoc.hostAndPushPort());
         }
-        if (StatusCode.SUCCESS.getValue() == statusCode) {
-          PartitionLocation loc =
-              PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition());
+        if (StatusCode.SUCCESS == statusCode) {
+          PartitionLocation loc = entry.getValue()._3();
           partitionLocationMap.put(partitionId, loc);
           pushExcludedWorkers.remove(loc.hostAndPushPort());
           if (loc.hasPeer()) {
             pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort());
           }
-        } else if (StatusCode.STAGE_ENDED.getValue() == statusCode) {
+        } else if (StatusCode.STAGE_ENDED == statusCode) {
          stageEndShuffleSet.add(shuffleId);
          return results;
-        } else if (StatusCode.SHUFFLE_UNREGISTERED.getValue() == statusCode) {
+        } else if (StatusCode.SHUFFLE_UNREGISTERED == statusCode) {
          logger.error("SHUFFLE_NOT_REGISTERED!");
          return null;
        }
-        results.put(partitionId, statusCode);
+        results.put(partitionId, (int) (statusCode.getValue()));
       }
 
       return results;
@@ -1806,7 +1801,8 @@ private void mapEndInternal(
               pushState.getFailedBatches(),
               numPartitions,
               crc32PerPartition,
-              bytesPerPartition),
+              bytesPerPartition,
+              SerdeVersion.V1),
           rpcMaxRetries,
           rpcRetryWait,
           ClassTag$.MODULE$.apply(MapperEndResponse.class));
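Note: reviveBatch now receives an in-memory ChangeLocationResponse instead of decoding protobuf itself, so the per-partition outcomes arrive as (StatusCode, Boolean, PartitionLocation) tuples. A hedged Scala sketch of the same consumption loop (`response`, `partitionLocationMap`, and `results` mirror the Java names above):

    // Sketch: walking ChangeLocationResponse.newLocs after this patch.
    import scala.collection.JavaConverters._

    response.newLocs.asScala.foreach { case (partitionId, (status, _, loc)) =>
      status match {
        case StatusCode.SUCCESS     => partitionLocationMap.put(partitionId, loc)
        case StatusCode.STAGE_ENDED => // stage already ended; stop reviving
        case other                  => results.put(partitionId, other.getValue.toInt)
      }
    }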
-        val info = partitionInfos.get(idx)
-        partitionIds.add(info.getPartitionId)
-        epochs.add(info.getEpoch)
-        if (info.hasPartition) {
-          oldPartitions.add(PbSerDeUtils.fromPbPartitionLocation(info.getPartition))
-        } else {
-          oldPartitions.add(null)
-        }
-        causes.add(StatusCode.fromValue(info.getStatus))
+      (0 until reviveRequests.size()).foreach { idx =>
+        val request = reviveRequests.get(idx)
+        partitionIds.add(request.partitionId)
+        epochs.add(request.epoch)
+        oldPartitions.add(request.loc)
+        causes.add(request.cause)
       }
       logDebug(s"Received Revive request, number of partitions ${partitionIds.size()}")
       handleRevive(
@@ -418,7 +410,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         partitionIds,
         epochs,
         oldPartitions,
-        causes)
+        causes,
+        serdeVersion)
 
     case pb: PbPartitionSplit =>
       val shuffleId = pb.getShuffleId
@@ -428,7 +421,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       logTrace(s"Received split request, " +
         s"$shuffleId, $partitionId, $epoch, $oldPartition")
       changePartitionManager.handleRequestPartitionLocation(
-        ChangeLocationsCallContext(context, 1),
+        // TODO: this message is not supported in cppClient yet.
+        ChangeLocationsCallContext(context, 1, SerdeVersion.V1),
         shuffleId,
         partitionId,
         epoch,
@@ -444,7 +438,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         pushFailedBatch,
         numPartitions,
         crc32PerPartition,
-        bytesWrittenPerPartition) =>
+        bytesWrittenPerPartition,
+        serdeVersion) =>
       logTrace(s"Received MapperEnd TaskEnd request, " +
         s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
       val partitionType = getPartitionType(shuffleId)
@@ -459,7 +454,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             pushFailedBatch,
             numPartitions,
             crc32PerPartition,
-            bytesWrittenPerPartition)
+            bytesWrittenPerPartition,
+            serdeVersion)
         case PartitionType.MAP =>
           handleMapPartitionEnd(
             context,
@@ -467,7 +463,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             mapId,
             attemptId,
             partitionId,
-            numMappers)
+            numMappers,
+            serdeVersion)
         case _ =>
           throw new UnsupportedOperationException(s"Not support $partitionType yet")
       }
@@ -618,6 +615,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
 
   private def offerAndReserveSlots(
       context: RegisterCallContext,
+      serdeVersion: SerdeVersion,
      shuffleId: Int,
      numMappers: Int,
      numPartitions: Int,
@@ -641,13 +639,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           processMapTaskReply(
             shuffleId,
             rpcContext,
+            serdeVersion,
             partitionId,
             getLatestLocs(shuffleId, p => p.getId == partitionId))
         case PartitionType.REDUCE =>
           if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
             context.reply(RegisterShuffleResponse(
               StatusCode.SUCCESS,
-              getLatestLocs(shuffleId, _ => true)))
+              getLatestLocs(shuffleId, _ => true),
+              serdeVersion))
           } else {
             val cachedMsg = registerShuffleResponseRpcCache.get(
               shuffleId,
@@ -656,7 +656,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
                 rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
                   RegisterShuffleResponse(
                     StatusCode.SUCCESS,
-                    getLatestLocs(shuffleId, _ => true)))
+                    getLatestLocs(shuffleId, _ => true),
+                    serdeVersion))
               }
             })
           rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@@ -699,15 +700,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     def processMapTaskReply(
         shuffleId: Int,
         context: RpcCallContext,
+        serdeVersion: SerdeVersion,
         partitionId: Int,
         partitionLocations: Array[PartitionLocation]): Unit = {
       // if any partition location resource exist just reply
       if (partitionLocations.size > 0) {
-        context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations))
+        context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations, serdeVersion))
       } else {
         // request new resource for this task
         changePartitionManager.handleRequestPartitionLocation(
-          ApplyNewLocationCallContext(context),
+          ApplyNewLocationCallContext(context, serdeVersion),
           shuffleId,
           partitionId,
           -1,
@@ -717,13 +719,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     }
 
     // Reply to all RegisterShuffle request for current shuffle id.
-    def replyRegisterShuffle(response: PbRegisterShuffleResponse): Unit = {
+    def replyRegisterShuffle(response: RegisterShuffleResponse): Unit = {
       registeringShuffleRequest.synchronized {
         val serializedMsg: Option[ByteBuffer] = partitionType match {
           case PartitionType.REDUCE =>
             context.context match {
               case remoteContext: RemoteNettyRpcCallContext =>
-                if (response.getStatus == StatusCode.SUCCESS.getValue) {
+                if (response.status == StatusCode.SUCCESS) {
                   Option(remoteContext.nettyEnv.serialize(
                     response))
                 } else {
@@ -735,19 +737,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           case _ => Option.empty
         }
 
-        val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
-          response.getPackedPartitionLocationsPair)._1.asScala
+        val locations = response.partitionLocations
 
         registeringShuffleRequest.asScala
           .get(shuffleId)
           .foreach(_.asScala.foreach(context => {
             partitionType match {
               case PartitionType.MAP =>
-                if (response.getStatus == StatusCode.SUCCESS.getValue) {
-                  val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
+                if (response.status == StatusCode.SUCCESS) {
+                  val partitionLocations = locations.filter(_.getId == context.partitionId)
                   processMapTaskReply(
                     shuffleId,
                     context.context,
+                    serdeVersion,
                     context.partitionId,
                     partitionLocations)
                 } else {
@@ -757,7 +759,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
               }
               case PartitionType.REDUCE =>
                 if (context.context.isInstanceOf[
-                    LocalNettyRpcCallContext] || response.getStatus != StatusCode.SUCCESS.getValue) {
+                    LocalNettyRpcCallContext] || response.status != StatusCode.SUCCESS) {
                   context.reply(response)
                 } else {
                   registerShuffleResponseRpcCache.put(shuffleId, serializedMsg.get)
@@ -780,17 +782,26 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         res.status match {
           case StatusCode.REQUEST_FAILED =>
             logInfo(s"OfferSlots RPC request failed for $shuffleId!")
-            replyRegisterShuffle(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
+            replyRegisterShuffle(RegisterShuffleResponse(
+              StatusCode.REQUEST_FAILED,
+              Array.empty,
+              serdeVersion))
             return
           case StatusCode.SLOT_NOT_AVAILABLE =>
             logInfo(s"OfferSlots for $shuffleId failed!")
-            replyRegisterShuffle(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
+            replyRegisterShuffle(RegisterShuffleResponse(
+              StatusCode.SLOT_NOT_AVAILABLE,
+              Array.empty,
+              serdeVersion))
            return
          case StatusCode.SUCCESS =>
            logDebug(s"OfferSlots for $shuffleId Success!Slots Info: ${res.workerResource}")
          case StatusCode.WORKER_EXCLUDED =>
            logInfo(s"OfferSlots for $shuffleId failed due to all workers be excluded!")
-            replyRegisterShuffle(RegisterShuffleResponse(StatusCode.WORKER_EXCLUDED, Array.empty))
+            replyRegisterShuffle(RegisterShuffleResponse(
+              StatusCode.WORKER_EXCLUDED,
+              Array.empty,
+              serdeVersion))
             return
           case _ => // won't happen
             throw new UnsupportedOperationException()
@@ -823,7 +834,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     // If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
     if (!reserveSlotsSuccess) {
       logError(s"reserve buffer for $shuffleId failed, reply to all.")
-      replyRegisterShuffle(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
+      replyRegisterShuffle(RegisterShuffleResponse(
+        StatusCode.RESERVE_SLOTS_FAILED,
+        Array.empty,
+        serdeVersion))
     } else {
       if (log.isDebugEnabled()) {
         logDebug(s"ReserveSlots for $shuffleId success with details:$slots!")
@@ -851,7 +865,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
       replyRegisterShuffle(RegisterShuffleResponse(
         StatusCode.SUCCESS,
-        allPrimaryPartitionLocations))
+        allPrimaryPartitionLocations,
+        serdeVersion))
     }
   }
 
@@ -862,9 +877,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       partitionIds: util.List[Integer],
       oldEpochs: util.List[Integer],
       oldPartitions: util.List[PartitionLocation],
-      causes: util.List[StatusCode]): Unit = {
+      causes: util.List[StatusCode],
+      serdeVersion: SerdeVersion): Unit = {
     val contextWrapper =
-      ChangeLocationsCallContext(context, partitionIds.size())
+      ChangeLocationsCallContext(context, partitionIds.size(), serdeVersion)
 
     // If shuffle not registered, reply ShuffleNotRegistered and return
     if (!registeredShuffle.contains(shuffleId)) {
       logError(s"[handleRevive] shuffle $shuffleId not registered!")
@@ -916,7 +932,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       pushFailedBatches: util.Map[String, LocationPushFailedBatches],
       numPartitions: Int,
       crc32PerPartition: Array[Int],
-      bytesWrittenPerPartition: Array[Long]): Unit = {
+      bytesWrittenPerPartition: Array[Long],
+      serdeVersion: SerdeVersion): Unit = {
     val (mapperAttemptFinishedSuccess, allMapperFinished) =
       commitManager.finishMapperAttempt(
@@ -936,7 +953,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     }
 
     // reply success
-    context.reply(MapperEndResponse(StatusCode.SUCCESS))
+    context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
   }
 
   private def handleGetReducerFileGroup(
@@ -1205,7 +1222,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       mapId: Int,
       attemptId: Int,
       partitionId: Int,
-      numMappers: Int): Unit = {
+      numMappers: Int,
+      serdeVersion: SerdeVersion): Unit = {
     def reply(result: Boolean): Unit = {
       val message =
         s"to handle MapPartitionEnd for ${Utils.makeMapKey(shuffleId, mapId, attemptId)}, " +
@@ -1213,10 +1231,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       result match {
         case true => // if already committed by another try
           logDebug(s"Succeed $message")
-          context.reply(MapperEndResponse(StatusCode.SUCCESS))
+          context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
         case false =>
           logError(s"Failed $message, reply ${StatusCode.SHUFFLE_DATA_LOST}.")
-          context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST))
+          context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST, serdeVersion))
       }
     }
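Note: the pattern in every handler above is that serdeVersion is never interpreted inside LifecycleManager; it is taken off the incoming message and stamped onto the reply so that serialization can later pick the right wire format for the caller. Condensed (names from this patch):

    // The receiveAndReply contract, condensed: echo the requester's version.
    case MapperEnd(shuffleId, mapId, attemptId, /* ... */ serdeVersion) =>
      // ...commit bookkeeping elided...
      context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))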
diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
index 091960a4c46..9de71dd46fc 100644
--- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
@@ -20,6 +20,7 @@ package org.apache.celeborn.client
 import java.util
 
 import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.PartitionLocation
 import org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse, RegisterShuffleResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -36,11 +37,12 @@ trait RequestLocationCallContext {
 
 case class ChangeLocationsCallContext(
     context: RpcCallContext,
-    partitionCount: Int)
+    partitionCount: Int,
+    serdeVersion: SerdeVersion)
   extends RequestLocationCallContext with Logging {
-  val endedMapIds = new util.HashSet[Integer]()
+  val endedMapIds = new util.ArrayList[Integer]()
   val newLocs =
-    JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, PartitionLocation)](
+    JavaUtils.newConcurrentHashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)](
       partitionCount)
 
   def markMapperEnd(mapId: Int): Unit = this.synchronized {
@@ -59,12 +61,13 @@ case class ChangeLocationsCallContext(
       if (newLocs.size() == partitionCount || StatusCode.SHUFFLE_UNREGISTERED == status ||
         StatusCode.STAGE_ENDED == status) {
-        context.reply(ChangeLocationResponse(endedMapIds, newLocs))
+        context.reply(ChangeLocationResponse(endedMapIds, newLocs, serdeVersion))
       }
     }
   }
 
-case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext {
+case class ApplyNewLocationCallContext(context: RpcCallContext, serdeVersion: SerdeVersion)
+  extends RequestLocationCallContext {
   override def reply(
       partitionId: Int,
       status: StatusCode,
@@ -72,8 +75,8 @@ case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext
       available: Boolean): Unit = {
     partitionLocationOpt match {
       case Some(partitionLocation) =>
-        context.reply(RegisterShuffleResponse(status, Array(partitionLocation)))
-      case None => context.reply(RegisterShuffleResponse(status, Array.empty))
+        context.reply(RegisterShuffleResponse(status, Array(partitionLocation), serdeVersion))
+      case None => context.reply(RegisterShuffleResponse(status, Array.empty, serdeVersion))
     }
   }
 }
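Note: ChangeLocationsCallContext is an accumulate-then-reply wrapper: it buffers one outcome per partition and sends a single ChangeLocationResponse once all partitionCount outcomes (or a terminal status) have arrived. A sketch, assuming `rpcContext`, `loc0`, and `loc1` are in scope:

    val wrapper = ChangeLocationsCallContext(rpcContext, partitionCount = 2, SerdeVersion.V1)
    wrapper.reply(0, StatusCode.SUCCESS, Some(loc0), available = false)
    // No RPC response yet: only one of two partitions has reported.
    wrapper.reply(1, StatusCode.SUCCESS, Some(loc1), available = false)
    // The second outcome completes the batch and triggers exactly one
    // context.reply(ChangeLocationResponse(endedMapIds, newLocs, SerdeVersion.V1)).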
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
index 7a7706973be..2cc2cd1fd62 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
@@ -30,9 +30,9 @@
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.protocol.CompressionCodec;
 import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
 import org.apache.celeborn.common.protocol.message.ControlMessages;
 import org.apache.celeborn.common.protocol.message.StatusCode;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
@@ -90,12 +90,14 @@ protected CelebornConf setupEnv(CompressionCodec codec) throws IOException, InterruptedException
     primaryLocation.setPeer(replicaLocation);
 
     when(endpointRef.askSync(
-            ControlMessages.RegisterShuffle$.MODULE$.apply(TEST_SHUFFLE_ID, 1, 1),
-            ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
+            new ControlMessages.RegisterShuffle(TEST_SHUFFLE_ID, 1, 1, SerdeVersion.V1),
+            ClassTag$.MODULE$.apply(ControlMessages.RegisterShuffleResponse.class)))
         .thenAnswer(
             t ->
-                ControlMessages.RegisterShuffleResponse$.MODULE$.apply(
-                    StatusCode.SUCCESS, new PartitionLocation[] {primaryLocation}));
+                new ControlMessages.RegisterShuffleResponse(
+                    StatusCode.SUCCESS,
+                    new PartitionLocation[] {primaryLocation},
+                    SerdeVersion.V1));
 
     shuffleClient.setupLifecycleManagerRef(endpointRef);
     when(clientFactory.createClient(
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index 0f4b5c30f4c..e6d450d87f2 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -263,13 +263,13 @@ private CelebornConf setupEnv(
         .thenAnswer(
             t ->
                 RegisterShuffleResponse$.MODULE$.apply(
-                    statusCode, new PartitionLocation[] {primaryLocation}));
+                    statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
         .thenAnswer(
             t ->
                 RegisterShuffleResponse$.MODULE$.apply(
-                    statusCode, new PartitionLocation[] {primaryLocation}));
+                    statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
 
     shuffleClient.setupLifecycleManagerRef(endpointRef);
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index eb9274632de..36f164d697e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -131,17 +131,12 @@ object ControlMessages extends Logging {
       workerEvent: WorkerEventType = WorkerEventType.None)
     extends MasterMessage
 
-  object RegisterShuffle {
-    def apply(
-        shuffleId: Int,
-        numMappers: Int,
-        numPartitions: Int): PbRegisterShuffle =
-      PbRegisterShuffle.newBuilder()
-        .setShuffleId(shuffleId)
-        .setNumMappers(numMappers)
-        .setNumPartitions(numPartitions)
-        .build()
-  }
+  case class RegisterShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      numPartitions: Int,
+      serdeVersion: SerdeVersion)
+    extends MasterMessage
 
   object RegisterMapPartitionTask {
     def apply(
@@ -161,17 +156,10 @@ object ControlMessages extends Logging {
         .build()
   }
 
-  object RegisterShuffleResponse {
-    def apply(
-        status: StatusCode,
-        partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
-      val builder = PbRegisterShuffleResponse.newBuilder()
-        .setStatus(status.getValue)
-      builder.setPackedPartitionLocationsPair(
-        PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
-      builder.build()
-    }
-  }
+  case class RegisterShuffleResponse(
+      status: StatusCode,
+      partitionLocations: Array[PartitionLocation],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   case class RequestSlots(
       applicationId: String,
@@ -195,29 +183,11 @@ object ControlMessages extends Logging {
       packed: Boolean = false)
     extends MasterMessage
 
-  object Revive {
-    def apply(
-        shuffleId: Int,
-        mapIds: util.Set[Integer],
-        reviveRequests: util.Collection[ReviveRequest]): PbRevive = {
-      val builder = PbRevive.newBuilder()
-        .setShuffleId(shuffleId)
-        .addAllMapId(mapIds)
-
-      reviveRequests.asScala.foreach { req =>
-        val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
-          .setPartitionId(req.partitionId)
-          .setEpoch(req.epoch)
-          .setStatus(req.cause.getValue)
-        if (req.loc != null) {
-          partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
-        }
-        builder.addPartitionInfo(partitionInfoBuilder.build())
-      }
-
-      builder.build()
-    }
-  }
+  case class Revive(
+      shuffleId: Int,
+      mapIds: util.List[Integer],
+      reviveRequests: util.List[ReviveRequest],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   object PartitionSplit {
     def apply(
@@ -233,26 +203,10 @@ object ControlMessages extends Logging {
         .build()
   }
 
-  object ChangeLocationResponse {
-    def apply(
-        mapIds: util.Set[Integer],
-        newLocs: util.Map[Integer, (StatusCode, Boolean, PartitionLocation)])
-        : PbChangeLocationResponse = {
-      val builder = PbChangeLocationResponse.newBuilder()
-      builder.addAllEndedMapId(mapIds)
-      newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
-        val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
-          .setPartitionId(partitionId)
-          .setStatus(status.getValue)
-          .setOldAvailable(available)
-        if (loc != null) {
-          pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
-        }
-        builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
-      }
-      builder.build()
-    }
-  }
+  case class ChangeLocationResponse(
+      endedMapIds: util.List[Integer],
+      newLocs: util.Map[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   case class MapperEnd(
       shuffleId: Int,
@@ -263,7 +217,8 @@ object ControlMessages extends Logging {
       failedBatchSet: util.Map[String, LocationPushFailedBatches],
       numPartitions: Int,
       crc32PerPartition: Array[Int],
-      bytesWrittenPerPartition: Array[Long])
+      bytesWrittenPerPartition: Array[Long],
+      serdeVersion: SerdeVersion)
     extends MasterMessage
 
   case class ReadReducerPartitionEnd(
@@ -275,7 +230,7 @@ object ControlMessages extends Logging {
       bytesWritten: Long)
     extends MasterMessage
 
-  case class MapperEndResponse(status: StatusCode) extends MasterMessage
+  case class MapperEndResponse(status: StatusCode, serdeVersion: SerdeVersion) extends MasterMessage
 
   case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage
 
@@ -674,14 +629,23 @@ object ControlMessages extends Logging {
           .build().toByteArray
       new TransportMessage(MessageType.HEARTBEAT_FROM_WORKER_RESPONSE, payload)
 
-    case pb: PbRegisterShuffle =>
-      new TransportMessage(MessageType.REGISTER_SHUFFLE, pb.toByteArray)
+    case RegisterShuffle(shuffleId, numMappers, numPartitions, serdeVersion) =>
+      val payload = PbRegisterShuffle.newBuilder()
+        .setShuffleId(shuffleId)
+        .setNumMappers(numMappers)
+        .setNumPartitions(numPartitions)
+        .build().toByteArray
+      new TransportMessage(MessageType.REGISTER_SHUFFLE, payload, serdeVersion)
 
     case pb: PbRegisterMapPartitionTask =>
       new TransportMessage(MessageType.REGISTER_MAP_PARTITION_TASK, pb.toByteArray)
 
-    case pb: PbRegisterShuffleResponse =>
-      new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, pb.toByteArray)
+    case RegisterShuffleResponse(status, partitionLocations, serdeVersion) =>
+      val payload = PbRegisterShuffleResponse.newBuilder()
+        .setStatus(status.getValue).setPackedPartitionLocationsPair(
+          PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
+        .build().toByteArray
+      new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, payload, serdeVersion)
 
     case RequestSlots(
         applicationId,
@@ -729,11 +693,39 @@ object ControlMessages extends Logging {
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.REQUEST_SLOTS_RESPONSE, payload)
 
-    case pb: PbRevive =>
-      new TransportMessage(MessageType.CHANGE_LOCATION, pb.toByteArray)
+    case Revive(shuffleId, mapIds, reviveRequests, serdeVersion) =>
+      val builder = PbRevive.newBuilder()
+        .setShuffleId(shuffleId)
+        .addAllMapId(mapIds)
 
-    case pb: PbChangeLocationResponse =>
-      new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
+      reviveRequests.asScala.foreach { req =>
+        val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
+          .setPartitionId(req.partitionId)
+          .setEpoch(req.epoch)
+          .setStatus(req.cause.getValue)
+        if (req.loc != null) {
+          partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
+        }
+        builder.addPartitionInfo(partitionInfoBuilder.build())
+      }
+      val payload = builder.build().toByteArray
+      new TransportMessage(MessageType.CHANGE_LOCATION, payload, serdeVersion)
+
+    case ChangeLocationResponse(mapIds, newLocs, serdeVersion) =>
+      val builder = PbChangeLocationResponse.newBuilder()
+      builder.addAllEndedMapId(mapIds)
+      newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
+        val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
+          .setPartitionId(partitionId)
+          .setStatus(status.getValue)
+          .setOldAvailable(available)
+        if (loc != null) {
+          pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
+        }
+        builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
+      }
+      val payload = builder.build().toByteArray
+      new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, payload, serdeVersion)
 
     case MapperEnd(
          shuffleId,
@@ -744,7 +736,8 @@ object ControlMessages extends Logging {
          mapId,
          attemptId,
          numMappers,
          partitionId,
          pushFailedBatch,
          numPartitions,
          crc32PerPartition,
-          bytesWrittenPerPartition) =>
+          bytesWrittenPerPartition,
+          serdeVersion) =>
       val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
         val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
         (k, resultValue)
@@ -761,13 +754,13 @@ object ControlMessages extends Logging {
         .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
           java.lang.Long.valueOf).toSeq.asJava)
         .build().toByteArray
-      new TransportMessage(MessageType.MAPPER_END, payload)
+      new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion)
 
-    case MapperEndResponse(status) =>
+    case MapperEndResponse(status, serdeVersion) =>
       val payload = PbMapperEndResponse.newBuilder()
         .setStatus(status.getValue)
         .build().toByteArray
-      new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload)
+      new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload, serdeVersion)
 
     case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, serdeVersion) =>
       val payload = PbGetReducerFileGroup.newBuilder()
@@ -1132,13 +1125,23 @@ object ControlMessages extends Logging {
         pbHeartbeatFromWorkerResponse.getWorkerEventType)
 
     case REGISTER_SHUFFLE_VALUE =>
-      PbRegisterShuffle.parseFrom(message.getPayload)
+      val pbRegisterShuffle = PbRegisterShuffle.parseFrom(message.getPayload)
+      RegisterShuffle(
+        pbRegisterShuffle.getShuffleId,
+        pbRegisterShuffle.getNumMappers,
+        pbRegisterShuffle.getNumPartitions,
+        message.getSerdeVersion)
 
     case REGISTER_MAP_PARTITION_TASK_VALUE =>
       PbRegisterMapPartitionTask.parseFrom(message.getPayload)
 
     case REGISTER_SHUFFLE_RESPONSE_VALUE =>
-      PbRegisterShuffleResponse.parseFrom(message.getPayload)
+      val pbRegisterShuffleResponse = PbRegisterShuffleResponse.parseFrom(message.getPayload)
+      RegisterShuffleResponse(
+        StatusCode.fromValue(pbRegisterShuffleResponse.getStatus),
+        PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+          pbRegisterShuffleResponse.getPackedPartitionLocationsPair)._1.asScala.toArray,
+        message.getSerdeVersion)
 
     case REQUEST_SLOTS_VALUE =>
       val pbRequestSlots = PbRequestSlots.parseFrom(message.getPayload)
@@ -1175,10 +1178,51 @@ object ControlMessages extends Logging {
         workerResource)
 
     case CHANGE_LOCATION_VALUE =>
-      PbRevive.parseFrom(message.getPayload)
+      val pbRevive = PbRevive.parseFrom(message.getPayload)
+      val shuffleId = pbRevive.getShuffleId
+      val partitionInfos = pbRevive.getPartitionInfoList
+      val reviveRequests = new util.ArrayList[ReviveRequest]()
+      (0 until partitionInfos.size).foreach { idx =>
+        val info = partitionInfos.get(idx)
+        var partition: PartitionLocation = null
+        if (info.hasPartition) {
+          partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+        }
+        val reviveRequest = new ReviveRequest(
+          shuffleId,
+          -1,
+          -1,
+          info.getPartitionId,
+          info.getEpoch,
+          partition,
+          StatusCode.fromValue(info.getStatus))
+        reviveRequests.add(reviveRequest)
+      }
+      Revive(
+        pbRevive.getShuffleId,
+        pbRevive.getMapIdList,
+        reviveRequests,
+        message.getSerdeVersion)
 
     case CHANGE_LOCATION_RESPONSE_VALUE =>
-      PbChangeLocationResponse.parseFrom(message.getPayload)
+      val pbChangeLocationResponse = PbChangeLocationResponse.parseFrom(message.getPayload)
+      val newLocs =
+        new util.HashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)]()
+      val partitionInfos = pbChangeLocationResponse.getPartitionInfoList
+      (0 until partitionInfos.size).foreach { idx =>
+        val info = partitionInfos.get(idx)
+        var partition: PartitionLocation = null
+        if (info.hasPartition) {
+          partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+        }
+        newLocs.put(
+          info.getPartitionId,
+          (StatusCode.fromValue(info.getStatus), info.getOldAvailable, partition))
+      }
+      ChangeLocationResponse(
+        pbChangeLocationResponse.getEndedMapIdList,
+        newLocs,
+        message.getSerdeVersion)
 
     case MAPPER_END_VALUE =>
       val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
@@ -1203,7 +1247,8 @@ object ControlMessages extends Logging {
         }.toMap.asJava,
         pbMapperEnd.getNumPartitions,
         crc32Array,
-        bytesWrittenPerPartitionArray)
+        bytesWrittenPerPartitionArray,
+        message.getSerdeVersion)
 
     case READ_REDUCER_PARTITION_END_VALUE =>
       val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
@@ -1220,7 +1265,9 @@ object ControlMessages extends Logging {
 
     case MAPPER_END_RESPONSE_VALUE =>
       val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
-      MapperEndResponse(StatusCode.fromValue(pbMapperEndResponse.getStatus))
+      MapperEndResponse(
+        StatusCode.fromValue(pbMapperEndResponse.getStatus),
+        message.getSerdeVersion)
 
     case GET_REDUCER_FILE_GROUP_VALUE =>
       val pbGetReducerFileGroup = PbGetReducerFileGroup.parseFrom(message.getPayload)
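Note: after this refactor the protobuf boundary lives entirely in toTransportMessage/fromTransportMessage, with the SerdeVersion carried on the TransportMessage envelope rather than inside the payload. A round-trip sketch in the style of the UtilsSuite test below (values hypothetical):

    val msg = RegisterShuffle(shuffleId = 1, numMappers = 4, numPartitions = 8, SerdeVersion.V1)
    val wire = Utils.toTransportMessage(msg).asInstanceOf[TransportMessage]
    val back = Utils.fromTransportMessage(wire).asInstanceOf[RegisterShuffle]
    assert(back == msg) // the version survives via TransportMessage.getSerdeVersion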
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index 7cadaf07ebf..8be472b6447 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -28,6 +28,7 @@ import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver}
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.identity.DefaultIdentityProvider
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, TransportModuleConstants}
 import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd}
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -149,7 +150,17 @@
 
   test("MapperEnd class convert with pb") {
     val mapperEnd =
-      MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray)
+      MapperEnd(
+        1,
+        1,
+        1,
+        2,
+        1,
+        Collections.emptyMap(),
+        1,
+        Array.emptyIntArray,
+        Array.emptyLongArray,
+        SerdeVersion.V1)
     val mapperEndTrans =
       Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
     assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
diff --git a/cpp/celeborn/tests/CMakeLists.txt b/cpp/celeborn/tests/CMakeLists.txt
index 0bc5e41c902..104607ce2f0 100644
--- a/cpp/celeborn/tests/CMakeLists.txt
+++ b/cpp/celeborn/tests/CMakeLists.txt
@@ -35,4 +35,29 @@ target_link_libraries(
 
 add_executable(cppDataSumWithReaderClient DataSumWithReaderClient.cpp)
 
-target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
\ No newline at end of file
+target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
+
+add_library(
+  dataSumWithWriterClient
+  DataSumWithWriterClient.cpp)
+
+target_link_libraries(
+  dataSumWithWriterClient
+  memory
+  utils
+  conf
+  proto
+  network
+  protocol
+  client
+  ${WANGLE}
+  ${FIZZ}
+  ${LIBSODIUM_LIBRARY}
+  ${FOLLY_WITH_DEPENDENCIES}
+  ${GLOG}
+  ${GFLAGS_LIBRARIES}
+)
+
+add_executable(cppDataSumWithWriterClient DataSumWithWriterClient.cpp)
+
+target_link_libraries(cppDataSumWithWriterClient dataSumWithWriterClient)
diff --git a/cpp/celeborn/tests/DataSumWithWriterClient.cpp b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
new file mode 100644
index 00000000000..ceaa01bebc9
--- /dev/null
+++ b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <celeborn/client/ShuffleClient.h>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+#include <folly/init/Init.h>
+
+int main(int argc, char** argv) {
+  folly::init(&argc, &argv, false);
+  // Read the configs.
+  assert(argc == 9);
+  std::string lifecycleManagerHost = argv[1];
+  int lifecycleManagerPort = std::atoi(argv[2]);
+  std::string appUniqueId = argv[3];
+  int shuffleId = std::atoi(argv[4]);
+  int attemptId = std::atoi(argv[5]);
+  int numMappers = std::atoi(argv[6]);
+  int numPartitions = std::atoi(argv[7]);
+  std::string resultFile = argv[8];
+  std::cout << "lifecycleManagerHost = " << lifecycleManagerHost
+            << ", lifecycleManagerPort = " << lifecycleManagerPort
+            << ", appUniqueId = " << appUniqueId
+            << ", shuffleId = " << shuffleId << ", attemptId = " << attemptId
+            << ", numMappers = " << numMappers
+            << ", numPartitions = " << numPartitions
+            << ", resultFile = " << resultFile << std::endl;
+
+  // Create shuffleClient and setup.
+  auto conf = std::make_shared<celeborn::conf::CelebornConf>();
+  auto clientEndpoint =
+      std::make_shared<celeborn::network::TransportClientFactory>(conf);
+  auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
+      appUniqueId, conf, *clientEndpoint);
+  shuffleClient->setupLifecycleManagerRef(
+      lifecycleManagerHost, lifecycleManagerPort);
+
+  long maxData = 1000000;
+  size_t numData = 1000;
+  // Generate data, sum up and pushData.
+  std::vector<long> result(numPartitions, 0);
+  std::vector<int> dataCnt(numPartitions, 0);
+  for (int mapId = 0; mapId < numMappers; mapId++) {
+    for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+      std::string partitionData;
+      for (size_t i = 0; i < numData; i++) {
+        int data = std::rand() % maxData;
+        result[partitionId] += data;
+        dataCnt[partitionId]++;
+        partitionData += "-" + std::to_string(data);
+      }
+      shuffleClient->pushData(
+          shuffleId,
+          mapId,
+          attemptId,
+          partitionId,
+          reinterpret_cast<const uint8_t*>(partitionData.c_str()),
+          0,
+          partitionData.size(),
+          numMappers,
+          numPartitions);
+    }
+    shuffleClient->mapperEnd(shuffleId, mapId, attemptId, numMappers);
+  }
+  for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+    std::cout << "partition " << partitionId
+              << " sum result = " << result[partitionId]
+              << ", dataCnt = " << dataCnt[partitionId] << std::endl;
+  }
+
+  // Write result to resultFile.
+  remove(resultFile.c_str());
+  std::ofstream of(resultFile);
+  for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+    of << result[partitionId] << std::endl;
+  }
+  of.close();
+
+  return 0;
+}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
new file mode 100644
index 00000000000..b7fb62a4af2
--- /dev/null
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import org.apache.celeborn.common.protocol.CompressionCodec
+
+object CppWriteJavaReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
+
+  def main(args: Array[String]) = {
+    testCppWriteJavaRead(CompressionCodec.NONE)
+  }
+}
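Note: wiring up another compression codec for the cpp-write direction should only require one more thin entry point plus a matching CI step; this patch deliberately adds only NONE (hence the "NONE Compression" step name in the workflow). A hypothetical LZ4 variant, for illustration only:

    // Hypothetical follow-up, not part of this patch.
    object CppWriteJavaReadTestWithLZ4 extends JavaCppHybridReadWriteTestBase {
      def main(args: Array[String]): Unit = testCppWriteJavaRead(CompressionCodec.LZ4)
    }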
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
similarity index 60%
rename from worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
rename to worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
index e059754e381..325f9c8b77a 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
@@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
@@ -35,7 +36,7 @@ import org.apache.celeborn.common.protocol.CompressionCodec
 import org.apache.celeborn.common.util.Utils.runCommand
 import org.apache.celeborn.service.deploy.MiniClusterFeature
 
-trait JavaWriteCppReadTestBase extends AnyFunSuite
+trait JavaCppHybridReadWriteTestBase extends AnyFunSuite
   with Logging with MiniClusterFeature with BeforeAndAfterAll {
   var masterPort = 0
 
@@ -147,4 +148,99 @@ trait JavaWriteCppReadTestBase extends AnyFunSuite
     shuffleClient.shutdown()
   }
 
+  def testCppWriteJavaRead(codec: CompressionCodec): Unit = {
+    beforeAll()
+    try {
+      runCppWriteJavaRead(codec)
+    } finally {
+      afterAll()
+    }
+  }
+
+  def runCppWriteJavaRead(codec: CompressionCodec): Unit = {
+    val appUniqueId = "test-app"
+    val shuffleId = 0
+    val attemptId = 0
+
+    // Create lifecycleManager.
+    val clientConf = new CelebornConf()
+      .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort")
+      .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, codec.name)
+      .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+      .set(CelebornConf.READ_LOCAL_SHUFFLE_FILE, false)
+      .set("celeborn.data.io.numConnectionsPerPeer", "1")
+    val lifecycleManager = new LifecycleManager(appUniqueId, clientConf)
+
+    // Create the reader shuffleClient.
+    val shuffleClient =
+      new ShuffleClientImpl(appUniqueId, clientConf, UserIdentifier("mock", "mock"))
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    val numMappers = 2
+    val numPartitions = 2
+
+    // Launch the cpp writer to push data, compute the per-partition sums, and write them to the result file.
+    val cppResultFile = "/tmp/celeborn-cpp-writer-result.txt"
+    val lifecycleManagerHost = lifecycleManager.getHost
+    val lifecycleManagerPort = lifecycleManager.getPort
+    val projectDirectory = new File(new File(".").getAbsolutePath)
+    val cppBinRelativeDirectory = "cpp/build/celeborn/tests/"
+    val cppBinFileName = "cppDataSumWithWriterClient"
+    val cppBinFilePath = s"$projectDirectory/$cppBinRelativeDirectory/$cppBinFileName"
+    // Execution command: $exec lifecycleManagerHost lifecycleManagerPort appUniqueId shuffleId attemptId numMappers numPartitions cppResultFile
+    val command = {
+      s"$cppBinFilePath $lifecycleManagerHost $lifecycleManagerPort $appUniqueId $shuffleId $attemptId $numMappers $numPartitions $cppResultFile"
+    }
+    println(s"run command: $command")
+    val commandOutput = runCommand(command)
+    println(s"command output: $commandOutput")
+
+    val metricsCallback = new MetricsCallback {
+      override def incBytesRead(bytesWritten: Long): Unit = {}
+      override def incReadTime(time: Long): Unit = {}
+    }
+
+    val sums = new util.ArrayList[Long](numPartitions)
+    for (partitionId <- 0 until numPartitions) {
+      sums.add(0)
+      val inputStream = shuffleClient.readPartition(
+        shuffleId,
+        partitionId,
+        attemptId,
+        0,
+        0,
+        Integer.MAX_VALUE,
+        metricsCallback)
+      var c = inputStream.read()
+      var data: Long = 0
+      var dataCnt = 0
+      while (c != -1) {
+        if (c == '-') {
+          sums.set(partitionId, sums.get(partitionId) + data)
+          data = 0
+          dataCnt += 1
+        } else {
+          assert(c >= '0' && c <= '9')
+          data *= 10
+          data += c - '0'
+        }
+        c = inputStream.read()
+      }
+      sums.set(partitionId, sums.get(partitionId) + data)
+      println(s"partition $partitionId sum result = ${sums.get(partitionId)}, dataCnt = $dataCnt")
+    }
+
+    // Verify the sum result.
+    var lineCount = 0
+    for (line <- Source.fromFile(cppResultFile, "utf-8").getLines.toList) {
+      val data = line.toLong
+      Assert.assertEquals(data, sums.get(lineCount))
+      lineCount += 1
+    }
+    Assert.assertEquals(lineCount, numPartitions)
+    lifecycleManager.stop()
+    shuffleClient.shutdown()
+  }
 }
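Note: to reproduce the new cpp-write test outside CI, the same invocation the workflow uses works from the repository root, provided the cpp binaries under cpp/build/celeborn/tests/ have been built first:

    build/mvn -pl worker \
    test-compile exec:java \
    -Dexec.classpathScope="test" \
    -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
    -Dexec.args="-XX:MaxDirectMemorySize=2G"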
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
index bc1961384ec..327754ed980 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithLZ4 extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithLZ4 extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.LZ4)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
index a649f8350ef..18bb8a418ca 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithNONE extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.NONE)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
index f2ba2e769c0..de7cdf10235 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithZSTD extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithZSTD extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.ZSTD)