14 changes: 2 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support saving to local filesystem path."
)
}
val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
instance.freqItemsets.write.parquet(dataPath)
ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
}
}

@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support loading from local filesystem path."
)
}
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
(metadata.metadata \ "numTrainingRecords").extract[Long]
}
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
val itemSupport = if (numTrainingRecords == 0L) {
Map.empty[Any, Double]
} else {
31 changes: 30 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite
import org.apache.spark.util.{Utils, VersionUtils}

/**
@@ -1142,4 +1143,32 @@ private[spark] object ReadWriteUtils {
spark.read.parquet(path).as[T].collect()
}
}

def saveDataFrame(path: String, df: DataFrame): Unit = {
if (localSavingModeState.get()) {
df match {
case d: org.apache.spark.sql.classic.DataFrame =>
val filePath = Paths.get(path)
Files.createDirectories(filePath.getParent)
ArrowFileReadWrite.save(d, filePath)
case o => throw new UnsupportedOperationException(
s"Unsupported dataframe type: ${o.getClass.getName}")
}
} else {
df.write.parquet(path)
}
}

def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
if (localSavingModeState.get()) {
spark match {
case s: org.apache.spark.sql.classic.SparkSession =>
ArrowFileReadWrite.load(s, Paths.get(path))
case o => throw new UnsupportedOperationException(
s"Unsupported session type: ${o.getClass.getName}")
}
} else {
spark.read.parquet(path)
}
}
}
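As an illustration of the intended call pattern (not part of the diff), a model writer and reader can delegate to the new helpers the same way the FPGrowth change above does; the `ExampleModelData` object and its method names below are hypothetical:

```scala
// Hypothetical sketch: delegate model-data I/O to ReadWriteUtils, which picks
// Parquet or a local Arrow IPC file based on localSavingModeState.
// Assumes it lives inside the org.apache.spark.ml package tree, since
// ReadWriteUtils is private[spark].
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{DataFrame, SparkSession}

private[ml] object ExampleModelData {
  def save(path: String, data: DataFrame): Unit = {
    val dataPath = new Path(path, "data").toString
    ReadWriteUtils.saveDataFrame(dataPath, data)
  }

  def load(path: String, spark: SparkSession): DataFrame = {
    val dataPath = new Path(path, "data").toString
    ReadWriteUtils.loadDataFrame(dataPath, spark)
  }
}
```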
2 changes: 1 addition & 1 deletion mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
FPGrowthSuite.allParamSettings, checkModelData)
}
}

32 changes: 24 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2340,14 +2340,13 @@ class Dataset[T] private[sql](
}

/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
private def toArrowBatchRddImpl(
plan: SparkPlan,
maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
val largeVarTypes =
sparkSession.sessionState.conf.arrowUseLargeVarTypes
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
@@ -2361,7 +2360,24 @@ class Dataset[T] private[sql](
}
}

// This is only used in tests, for now.
private[sql] def toArrowBatchRdd(
maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
toArrowBatchRddImpl(queryExecution.executedPlan,
maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
}

private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
toArrowBatchRddImpl(
plan,
sparkSession.sessionState.conf.arrowMaxRecordsPerBatch,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy",
sparkSession.sessionState.conf.arrowUseLargeVarTypes)
}

private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
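A brief sketch of the new explicit-parameter overload (not from the PR itself): because the method is `private[sql]`, it can only be called from code in the `org.apache.spark.sql` package, such as `ArrowFileReadWrite.save` further down; `df` is a placeholder classic `Dataset`:

```scala
// Sketch only: pin the Arrow batch options explicitly instead of reading
// them from the session conf.
val batches: org.apache.spark.rdd.RDD[Array[Byte]] = df.toArrowBatchRdd(
  maxRecordsPerBatch = 10000,
  timeZoneId = "UTC",
  errorOnDuplicatedFieldNames = true,
  largeVarTypes = false)
```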
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -128,8 +128,7 @@ private[sql] object ArrowConverters extends Logging {
}

override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))
var bytes: Array[Byte] = null

Utils.tryWithSafeFinally {
var rowCount = 0L
@@ -140,13 +139,13 @@ private[sql] object ArrowConverters extends Logging {
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
bytes = serializeBatch(batch)
batch.close()
} {
arrowWriter.reset()
}

out.toByteArray
bytes
}

override def close(): Unit = {
@@ -548,32 +547,55 @@ private[sql] object ArrowConverters extends Logging {
new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
}

private[arrow] def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))
MessageSerializer.serialize(writeChannel, batch)
out.toByteArray
}

/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
def toDataFrame(
arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
toDataFrame(
arrowBatches,
DataType.fromJson(schemaString).asInstanceOf[StructType],
session,
session.sessionState.conf.sessionLocalTimeZone,
false,
session.sessionState.conf.arrowUseLargeVarTypes)
}

/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
arrowBatches: Iterator[Array[Byte]],
schema: StructType,
session: SparkSession,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): DataFrame = {
val attrs = toAttributes(schema)
val batchesInDriver = arrowBatches.toArray
val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
val shouldUseRDD = session.sessionState.conf
.arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum

if (shouldUseRDD) {
logDebug("Using RDD-based createDataFrame with Arrow optimization.")
val timezone = session.sessionState.conf.sessionLocalTimeZone
val rdd = session.sparkContext
.parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length)
.mapPartitions { batchesInExecutors =>
ArrowConverters.fromBatchIterator(
batchesInExecutors,
schema,
timezone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
Expand All @@ -582,9 +604,9 @@ private[sql] object ArrowConverters extends Logging {
val data = ArrowConverters.fromBatchIterator(
batchesInDriver.iterator,
schema,
session.sessionState.conf.sessionLocalTimeZone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())

// Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
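Aside (not part of the diff): the new `serializeBatch` helper pairs with the existing `loadBatch` for an in-memory round trip; `batch` and `allocator` below are assumed to already exist, and both helpers are package-private, so this only compiles inside the Spark SQL arrow package:

```scala
// Sketch: serialize a record batch to bytes and load it back.
val bytes: Array[Byte] = ArrowConverters.serializeBatch(batch)
val restored = ArrowConverters.loadBatch(bytes, allocator) // may throw IOException
restored.close()
```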
102 changes: 102 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.nio.channels.Channels
import java.nio.file.{Files, Path}

import scala.jdk.CollectionConverters._

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkArrowFileWriter(schema: Schema, path: Path) extends AutoCloseable {
private val allocator = ArrowUtils.rootAllocator
.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val root = VectorSchemaRoot.create(schema, allocator)
protected val loader = new VectorLoader(root)

protected val fileWriter =
new ArrowFileWriter(root, null, Channels.newChannel(Files.newOutputStream(path)))

override def close(): Unit = {
fileWriter.close()
root.close()
allocator.close()
}

def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
fileWriter.start()
while (batchBytesIter.hasNext) {
val batchBytes = batchBytesIter.next()
val batch = ArrowConverters.loadBatch(batchBytes, allocator)
[Review comment] @zhengruifeng (Contributor, Author), Nov 26, 2025:
The batch (ArrowRecordBatch) doesn't extend Serializable, so the PR still uses Array[Byte] as the underlying data.

loader.load(batch)
fileWriter.writeBatch()
batch.close()
}
fileWriter.close()
}
}

private[sql] class SparkArrowFileReader(path: Path) extends AutoCloseable {
private val allocator = ArrowUtils.rootAllocator
.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val fileReader =
new ArrowFileReader(Files.newByteChannel(path), allocator)

override def close(): Unit = {
fileReader.close()
allocator.close()
}

val schema: Schema = fileReader.getVectorSchemaRoot.getSchema

def read(): Iterator[Array[Byte]] = {
fileReader.getRecordBlocks.iterator().asScala.map { block =>
fileReader.loadRecordBatch(block)
val root = fileReader.getVectorSchemaRoot
val unloader = new VectorUnloader(root)
val batch = unloader.getRecordBatch
val bytes = ArrowConverters.serializeBatch(batch)
batch.close()
bytes
}
}
}

private[spark] object ArrowFileReadWrite {
def save(df: DataFrame, path: Path): Unit = {
val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
val writer = new SparkArrowFileWriter(arrowSchema, path)
writer.write(rdd.toLocalIterator)
[Review comment] @viirya (Member), Nov 27, 2025:
Instead, can we call toLocalIterator on the original DataFrame's RDD and write rows to Arrow batches locally? Then we don't need the redundant bytes.

[Reply] @zhengruifeng (Contributor, Author):
We can make the best use of the ArrowConverters utils if we use the bytes.

}

def load(spark: SparkSession, path: Path): DataFrame = {
val reader = new SparkArrowFileReader(path)
val schema = ArrowUtils.fromArrowSchema(reader.schema)
ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false)
}
}
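A hypothetical end-to-end round trip through the new helpers, assuming an existing classic `SparkSession` named `spark` and a `DataFrame` named `df`; the path is a placeholder:

```scala
import java.nio.file.{Files, Paths}

// Write df to a single Arrow IPC file on the local filesystem, then read it back.
val dataPath = Paths.get("/tmp/example-model/data")
Files.createDirectories(dataPath.getParent)

ArrowFileReadWrite.save(df, dataPath)
val restored = ArrowFileReadWrite.load(spark, dataPath)
assert(restored.count() == df.count())
```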