14 changes: 2 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support saving to local filesystem path."
)
}
val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
instance.freqItemsets.write.parquet(dataPath)
ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
}
}

@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support loading from local filesystem path."
)
}
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
(metadata.metadata \ "numTrainingRecords").extract[Long]
}
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
val itemSupport = if (numTrainingRecords == 0L) {
Map.empty[Any, Double]
} else {
30 changes: 29 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite
import org.apache.spark.util.{Utils, VersionUtils}

/**
@@ -1142,4 +1143,31 @@ private[spark] object ReadWriteUtils {
spark.read.parquet(path).as[T].collect()
}
}

def saveDataFrame(path: String, df: DataFrame): Unit = {
if (localSavingModeState.get()) {
val filePath = Paths.get(path)
Files.createDirectories(filePath.getParent)

df match {
case d: org.apache.spark.sql.classic.DataFrame =>
ArrowFileReadWrite.save(d, path)
case _ => throw new UnsupportedOperationException("Unsupported dataframe type")
}
} else {
df.write.parquet(path)
}
}

def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
if (localSavingModeState.get()) {
spark match {
case s: org.apache.spark.sql.classic.SparkSession =>
ArrowFileReadWrite.load(s, path)
case _ => throw new UnsupportedOperationException("Unsupported session type")
}
} else {
spark.read.parquet(path)
}
}
Comment on lines 1147 to 1173
Contributor (@holdenk):
So if localSavingModeState is set to true, this will write out an Arrow file, which is not a stable format. It does look like localSavingModeState is only set to true in internal methods in Scala. Looking at the PySpark docstrings, I see we tell people to use this API, so I remain -0.9.

Contributor Author (@zhengruifeng):
Hi @holdenk, as @WeichenXu123 explained in #53150 (comment), this is a runtime temporary file on the Spark Connect server side and will be cleaned up after the session closes.
So I think we don't have to use a stable format here.

Contributor (@WeichenXu123, Nov 27, 2025):
localSavingModeState is also used internally (only Spark driver code can set the flag). Where does the docstring mention it? We should remove it from the doc and mark localSavingModeState as a private field.

Member:
Hmm, even if it is just a temporary session file, is there any reason to use the Arrow file format rather than Parquet?

Contributor Author (@zhengruifeng):
We can read/write Parquet with Arrow, but it requires a new dependency:

<dependency>
    <groupId>org.apache.parquet</groupId>
    <artifactId>parquet-arrow</artifactId>
</dependency>

Otherwise, I am not sure whether we have utils to read/write Parquet.

Contributor:
In the end we need the in-memory data to be in Arrow format, so using an Arrow file is more efficient.

}
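
For reference, a minimal usage sketch of the two new helpers (not part of this PR; the example object, path, and data are hypothetical, and ReadWriteUtils is private[spark], so the sketch pretends to live in the same package). Which format is written depends only on the internal localSavingModeState flag:

package org.apache.spark.ml.util

import org.apache.spark.sql.SparkSession

// Hypothetical example, for illustration only.
object SaveDataFrameExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.range(0, 10).toDF("id")
    val dataPath = "/tmp/fpgrowth-model/data" // hypothetical path

    // In the default (distributed) mode this writes a Parquet directory;
    // when localSavingModeState is set (internal, Spark Connect server side),
    // the same call writes a single local Arrow IPC file instead.
    ReadWriteUtils.saveDataFrame(dataPath, df)

    // Symmetric read path: Arrow file in local saving mode, Parquet otherwise.
    val restored = ReadWriteUtils.loadDataFrame(dataPath, spark)
    assert(restored.count() == df.count())

    spark.stop()
  }
}

The FPGrowthModel writer/reader changes above are exactly this pattern applied to instance.freqItemsets.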
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
FPGrowthSuite.allParamSettings, checkModelData)
}
}

32 changes: 24 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2340,14 +2340,13 @@ class Dataset[T] private[sql](
}

/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
private def toArrowBatchRddImpl(
plan: SparkPlan,
maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
val largeVarTypes =
sparkSession.sessionState.conf.arrowUseLargeVarTypes
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
@@ -2361,7 +2360,24 @@
}
}

// This is only used in tests, for now.
private[sql] def toArrowBatchRdd(
maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
toArrowBatchRddImpl(queryExecution.executedPlan,
maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
}

private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
toArrowBatchRddImpl(
plan,
sparkSession.sessionState.conf.arrowMaxRecordsPerBatch,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy",
sparkSession.sessionState.conf.arrowUseLargeVarTypes)
}

private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
@@ -555,25 +555,41 @@
arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
toDataFrame(
arrowBatches,
DataType.fromJson(schemaString).asInstanceOf[StructType],
session,
session.sessionState.conf.sessionLocalTimeZone,
false,
session.sessionState.conf.arrowUseLargeVarTypes)
}

/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
arrowBatches: Iterator[Array[Byte]],
schema: StructType,
session: SparkSession,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): DataFrame = {
val attrs = toAttributes(schema)
val batchesInDriver = arrowBatches.toArray
val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
val shouldUseRDD = session.sessionState.conf
.arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum

if (shouldUseRDD) {
logDebug("Using RDD-based createDataFrame with Arrow optimization.")
val timezone = session.sessionState.conf.sessionLocalTimeZone
val rdd = session.sparkContext
.parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length)
.mapPartitions { batchesInExecutors =>
ArrowConverters.fromBatchIterator(
batchesInExecutors,
schema,
timezone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -582,9 +598,9 @@
val data = ArrowConverters.fromBatchIterator(
batchesInDriver.iterator,
schema,
session.sessionState.conf.sessionLocalTimeZone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())

// Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.io.{ByteArrayOutputStream, FileOutputStream}
import java.nio.channels.Channels
import java.nio.file.Files
import java.nio.file.Paths

import scala.jdk.CollectionConverters._

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkArrowFileWriter(
arrowSchema: Schema,
path: String) extends AutoCloseable {
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val loader = new VectorLoader(root)
protected val arrowWriter = ArrowWriter.create(root)

protected val fileWriter =
new ArrowFileWriter(root, null, Channels.newChannel(new FileOutputStream(path)))

override def close(): Unit = {
root.close()
allocator.close()
fileWriter.close()
}

def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
fileWriter.start()
while (batchBytesIter.hasNext) {
val batchBytes = batchBytesIter.next()
val batch = ArrowConverters.loadBatch(batchBytes, allocator)
Contributor Author (@zhengruifeng, Nov 26, 2025):
The batch (ArrowRecordBatch) doesn't extend Serializable, so the PR still uses Array[Byte] as the underlying data.

loader.load(batch)
fileWriter.writeBatch()
}
fileWriter.close()
}
}

private[sql] class SparkArrowFileReader(path: String) extends AutoCloseable {
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val fileReader =
new ArrowFileReader(Files.newByteChannel(Paths.get(path)), allocator)

override def close(): Unit = {
allocator.close()
fileReader.close()
}

val schema: Schema = fileReader.getVectorSchemaRoot.getSchema

def read(): Iterator[Array[Byte]] = {
fileReader.getRecordBlocks.iterator().asScala.map { block =>
fileReader.loadRecordBatch(block)
val root = fileReader.getVectorSchemaRoot
val unloader = new VectorUnloader(root)
val batch = unloader.getRecordBatch
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))
MessageSerializer.serialize(writeChannel, batch)
out.toByteArray
}
}
}

private[spark] object ArrowFileReadWrite {
def save(df: DataFrame, path: String): Unit = {
val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
val writer = new SparkArrowFileWriter(arrowSchema, path)
writer.write(rdd.toLocalIterator)
Member (@viirya, Nov 27, 2025):
Instead, can we call toLocalIterator on the original DataFrame's RDD and write the rows to Arrow batches locally? Then we don't need the redundant byte arrays?

Contributor Author (@zhengruifeng):
We can make the best use of the ArrowConverters utils if we use the byte arrays.

}

def load(spark: SparkSession, path: String): DataFrame = {
val reader = new SparkArrowFileReader(path)
val schema = ArrowUtils.fromArrowSchema(reader.schema)
ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false)
}
}
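
A rough round-trip sketch of the new object (not part of this PR): it assumes driver-side code, a classic (non-Connect) session obtained via the standard builder, and a hypothetical file path; ArrowFileReadWrite is private[spark], so the sketch pretends to live in its package.

package org.apache.spark.sql.execution.arrow

import org.apache.spark.sql.classic.SparkSession

// Hypothetical example, for illustration only.
object ArrowFileReadWriteExample {
  def main(args: Array[String]): Unit = {
    // On a classic (non-Connect) deployment, getOrCreate() yields a classic
    // SparkSession; it is matched here the same way loadDataFrame does above.
    org.apache.spark.sql.SparkSession.builder().master("local[2]").getOrCreate() match {
      case spark: SparkSession =>
        val df = spark.range(0, 100).toDF("id")
        val path = "/tmp/arrow-roundtrip.arrow" // hypothetical temp file

        // save() pulls the Arrow batches to the driver via rdd.toLocalIterator
        // and streams them into a single Arrow IPC file.
        ArrowFileReadWrite.save(df, path)

        // load() reads the batches back and rebuilds a DataFrame through
        // ArrowConverters.toDataFrame with the fixed options ("UTC", true, false).
        val restored = ArrowFileReadWrite.load(spark, path)
        assert(restored.count() == df.count())
        spark.stop()
      case other =>
        throw new UnsupportedOperationException(s"Unsupported session type: $other")
    }
  }
}

This mirrors what ReadWriteUtils.saveDataFrame/loadDataFrame do when localSavingModeState is set.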