diff --git a/build.sbt b/build.sbt index 7ba262c62..3881a35b9 100644 --- a/build.sbt +++ b/build.sbt @@ -240,7 +240,14 @@ lazy val datasetSettings = mc("frameless.functions.FramelessLit"), mc(f"frameless.functions.FramelessLit$$"), dmm("frameless.functions.package.litAggr"), - dmm("org.apache.spark.sql.FramelessInternals.column") + dmm("org.apache.spark.sql.FramelessInternals.column"), + dmm("frameless.TypedEncoder.collectionEncoder"), + dmm("frameless.TypedEncoder.setEncoder"), + dmm("frameless.functions.FramelessUdf.evalCode"), + dmm("frameless.functions.FramelessUdf.copy"), + dmm("frameless.functions.FramelessUdf.this"), + dmm("frameless.functions.FramelessUdf.apply"), + imt("frameless.functions.FramelessUdf.apply") ) }, coverageExcludedPackages := "org.apache.spark.sql.reflection", diff --git a/dataset/src/main/scala/frameless/CollectionCaster.scala b/dataset/src/main/scala/frameless/CollectionCaster.scala new file mode 100644 index 000000000..bf329992e --- /dev/null +++ b/dataset/src/main/scala/frameless/CollectionCaster.scala @@ -0,0 +1,67 @@ +package frameless + +import frameless.TypedEncoder.CollectionConversion +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenContext, + CodegenFallback, + ExprCode +} +import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression } +import org.apache.spark.sql.types.{ DataType, ObjectType } + +case class CollectionCaster[F[_], C[_], Y]( + child: Expression, + conversion: CollectionConversion[F, C, Y]) + extends UnaryExpression + with CodegenFallback { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: F[Y] @unchecked => + conversion.convert(col) + case _ => o + } + } + + override def dataType: DataType = child.dataType +} + +case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) + extends UnaryExpression { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: Set[Y] @unchecked => + col.toSeq + case _ => o + } + } + + def toSeqOr[T](isSet: => T, or: => T): T = + child.dataType match { + case ObjectType(cls) + if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + isSet + case t => or + } + + override def dataType: DataType = + toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType) + + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = + defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c")) + +} diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index b42b026ee..928a05d6e 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -1,15 +1,10 @@ package frameless import java.math.BigInteger - import java.util.Date - -import java.time.{ Duration, Instant, Period, LocalDate } - +import java.time.{ Duration, Instant, LocalDate, Period } import java.sql.Timestamp - import scala.reflect.ClassTag - import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.FramelessInternals.UserDefinedType import org.apache.spark.sql.{ reflection => ScalaReflection } @@ -22,10 +17,11 @@ import org.apache.spark.sql.catalyst.util.{ } import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import shapeless._ import shapeless.ops.hlist.IsHCons +import scala.collection.immutable.{ ListSet, TreeSet } + abstract class TypedEncoder[T]( implicit val classTag: ClassTag[T]) @@ -501,10 +497,76 @@ object TypedEncoder { override def toString: String = s"arrayEncoder($jvmRepr)" } - implicit def collectionEncoder[C[X] <: Seq[X], T]( + /** + * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation + * + * This type class offers extensible conversion for more specific types. By default Seq, List and Vector for Seq's and Set, TreeSet and ListSet are supported. + * + * @tparam C + */ + trait CollectionConversion[F[_], C[_], Y] extends Serializable { + def convert(c: F[Y]): C[Y] + } + + object CollectionConversion { + + implicit def seqToSeq[Y] = new CollectionConversion[Seq, Seq, Y] { + + override def convert(c: Seq[Y]): Seq[Y] = + c match { + // Stream is produced + case _: Stream[Y] @unchecked => c.toVector.toSeq + case _ => c + } + } + + implicit def seqToVector[Y] = new CollectionConversion[Seq, Vector, Y] { + override def convert(c: Seq[Y]): Vector[Y] = c.toVector + } + + implicit def seqToList[Y] = new CollectionConversion[Seq, List, Y] { + override def convert(c: Seq[Y]): List[Y] = c.toList + } + + implicit def setToSet[Y] = new CollectionConversion[Set, Set, Y] { + override def convert(c: Set[Y]): Set[Y] = c + } + + implicit def setToTreeSet[Y]( + implicit + ordering: Ordering[Y] + ) = new CollectionConversion[Set, TreeSet, Y] { + + override def convert(c: Set[Y]): TreeSet[Y] = + TreeSet.newBuilder.++=(c).result() + } + + implicit def setToListSet[Y] = new CollectionConversion[Set, ListSet, Y] { + + override def convert(c: Set[Y]): ListSet[Y] = + ListSet.newBuilder.++=(c).result() + } + } + + implicit def seqEncoder[C[X] <: Seq[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Seq, C, T] + ) = collectionEncoder[Seq, C, T] + + implicit def setEncoder[C[X] <: Set[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Set, C, T] + ) = collectionEncoder[Set, C, T] + + def collectionEncoder[O[_], C[X], T]( implicit i0: Lazy[RecordFieldEncoder[T]], - i1: ClassTag[C[T]] + i1: ClassTag[C[T]], + i2: CollectionConversion[O, C, T] ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] { private lazy val encodeT = i0.value.encoder @@ -521,38 +583,31 @@ object TypedEncoder { if (ScalaReflection.isNativeType(enc.jvmRepr)) { NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr) } else { - MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable) + // converts to Seq, both Set and Seq handling must convert to Seq first + MapObjects( + enc.toCatalyst, + SeqCaster(path), + enc.jvmRepr, + encodeT.nullable + ) } } def fromCatalyst(path: Expression): Expression = - MapObjects( - i0.value.fromCatalyst, - path, - encodeT.catalystRepr, - encodeT.nullable, - Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly - ) + CollectionCaster[O, C, T]( + MapObjects( + i0.value.fromCatalyst, + path, + encodeT.catalystRepr, + encodeT.nullable, + Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling + ), + implicitly[CollectionConversion[O, C, T]] + ) // This will convert Seq to the appropriate C[_] when eval'ing. override def toString: String = s"collectionEncoder($jvmRepr)" } - /** - * @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set. - * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type. - * @tparam T the element type of the set. - * @return a `TypedEncoder` instance for `Set[T]`. - */ - implicit def setEncoder[T]( - implicit - i1: shapeless.Lazy[RecordFieldEncoder[T]], - i2: ClassTag[Set[T]] - ): TypedEncoder[Set[T]] = { - implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet) - - TypedEncoder.usingInjection - } - /** * @tparam A the key type * @tparam B the value type diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index 93ba7f118..c34e8561e 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -2,132 +2,179 @@ package frameless package functions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + LeafExpression, + NonSQLExpression +} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.DataType import shapeless.syntax.std.tuple._ -/** Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ +/** + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ trait Udf { - /** Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A, R: TypedEncoder](f: A => R): TypedColumn[T, A] => TypedColumn[T, R] = { u => - val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R]) + val scalaUdf = FramelessUdf( + f, + List(u), + TypedEncoder[R], + s => f(s.head.asInstanceOf[A]) + ) new TypedColumn[T, R](scalaUdf) } - /** Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, R: TypedEncoder](f: (A1, A2) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2]) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1, A2, A3) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3], + s(3).asInstanceOf[A4] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3], + s(3).asInstanceOf[A4], + s(4).asInstanceOf[A5] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } } /** - * NB: Implementation detail, isn't intended to be directly used. - * - * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. - */ + * NB: Implementation detail, isn't intended to be directly used. + * + * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. + */ +// Possibly add UserDefinedExpression trait to stop the functions being registered and used as aggregates case class FramelessUdf[T, R]( - function: AnyRef, - encoders: Seq[TypedEncoder[_]], - children: Seq[Expression], - rencoder: TypedEncoder[R] -) extends Expression with NonSQLExpression { + function: AnyRef, + encoders: Seq[TypedEncoder[_]], + children: Seq[Expression], + rencoder: TypedEncoder[R], + evalFunction: Seq[Any] => Any) + extends Expression + with NonSQLExpression { override def nullable: Boolean = rencoder.nullable + override def toString: String = s"FramelessUdf(${children.mkString(", ")})" - lazy val evalCode = { - val ctx = new CodegenContext() - val eval = genCode(ctx) + lazy val typedEnc = + TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]] - val codeBody = s""" - public scala.Function1 generate(Object[] references) { - return new FramelessUdfEvalImpl(references); - } + lazy val isSerializedAsStructForTopLevel = + typedEnc.isSerializedAsStructForTopLevel - class FramelessUdfEvalImpl extends scala.runtime.AbstractFunction1 { - private final Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} - - public FramelessUdfEvalImpl(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public java.lang.Object apply(java.lang.Object z) { - InternalRow ${ctx.INPUT_ROW} = (InternalRow) z; - ${eval.code} - return ${eval.isNull} ? ((Object)null) : ((Object)${eval.value}); - } - } - """ - - val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + def eval(input: InternalRow): Any = { + val jvmTypes = children.map(_.eval(input)) - val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + val returnJvm = evalFunction(jvmTypes).asInstanceOf[R] - codegen - } + val returnCatalyst = typedEnc.createSerializer().apply(returnJvm) + val retval = + if (returnCatalyst == null) + null + else if (isSerializedAsStructForTopLevel) + returnCatalyst + else + returnCatalyst.get(0, dataType) - def eval(input: InternalRow): Any = { - evalCode(input) + retval } def dataType: DataType = rencoder.catalystRepr @@ -139,29 +186,45 @@ case class FramelessUdf[T, R]( val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName val funcClassName = s"scala.Function${children.size}" val funcExpressionIdx = ctx.references.size - 1 - val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"), - v => s"$v = ($funcClassName)((($framelessUdfClassName)references" + - s"[$funcExpressionIdx]).function());") - - val (argsCode, funcArguments) = encoders.zip(children).map { - case (encoder, child) => - val eval = child.genCode(ctx) - val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) - val argTerm = ctx.freshName("arg") - val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + val funcTerm = ctx.addMutableState( + funcClassName, + ctx.freshName("udf"), + v => + s"$v = ($funcClassName)((($framelessUdfClassName)references" + + s"[$funcExpressionIdx]).function());" + ) - (convert, argTerm) - }.unzip + val (argsCode, funcArguments) = encoders + .zip(children) + .map { + case (encoder, child) => + val eval = child.genCode(ctx) + val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) + val argTerm = ctx.freshName("arg") + val convert = + s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + + (convert, argTerm) + } + .unzip val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr) - val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal")) - val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull")) + val internalTerm = + ctx.addMutableState(internalTpe, ctx.freshName("internal")) + val internalNullTerm = + ctx.addMutableState("boolean", ctx.freshName("internalNull")) // CTw - can't inject the term, may have to duplicate old code for parity - val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true) + val internalExpr = Spark2_4_LambdaVariable( + internalTerm, + internalNullTerm, + rencoder.jvmRepr, + true + ) val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx) - ev.copy(code = code""" + ev.copy( + code = code""" ${argsCode.mkString("\n")} $internalTerm = @@ -175,21 +238,28 @@ case class FramelessUdf[T, R]( ) } - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(children = newChildren) } case class Spark2_4_LambdaVariable( - value: String, - isNull: String, - dataType: DataType, - nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) + extends LeafExpression + with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = + InternalRow.getAccessor(dataType) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { - assert(input.numFields == 1, - "The input row of interpreted LambdaVariable should have only 1 field.") + assert( + input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field." + ) if (nullable && input.isNullAt(0)) { null } else { @@ -197,7 +267,10 @@ case class Spark2_4_LambdaVariable( } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = { val isNullValue = if (nullable) { JavaCode.isNullVariable(isNull) } else { @@ -208,15 +281,18 @@ case class Spark2_4_LambdaVariable( } object FramelessUdf { + // Spark needs case class with `children` field to mutate it def apply[T, R]( - function: AnyRef, - cols: Seq[UntypedExpression[T]], - rencoder: TypedEncoder[R] - ): FramelessUdf[T, R] = FramelessUdf( + function: AnyRef, + cols: Seq[UntypedExpression[T]], + rencoder: TypedEncoder[R], + evalFunction: Seq[Any] => Any + ): FramelessUdf[T, R] = FramelessUdf( function = function, encoders = cols.map(_.uencoder).toList, children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList, - rencoder = rencoder + rencoder = rencoder, + evalFunction = evalFunction ) } diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala index 4ebf5d93f..ab1f35811 100644 --- a/dataset/src/test/scala/frameless/EncoderTests.scala +++ b/dataset/src/test/scala/frameless/EncoderTests.scala @@ -1,7 +1,6 @@ package frameless -import scala.collection.immutable.Set - +import scala.collection.immutable.{ ListSet, Set, TreeSet } import org.scalatest.matchers.should.Matchers object EncoderTests { @@ -10,6 +9,8 @@ object EncoderTests { case class InstantRow(i: java.time.Instant) case class DurationRow(d: java.time.Duration) case class PeriodRow(p: java.time.Period) + + case class ContainerOf[CC[X] <: Iterable[X]](a: CC[X1[Int]]) } class EncoderTests extends TypedDatasetSuite with Matchers { @@ -32,4 +33,55 @@ class EncoderTests extends TypedDatasetSuite with Matchers { test("It should encode java.time.Period") { implicitly[TypedEncoder[PeriodRow]] } + + def performCollection[C[X] <: Iterable[X]]( + toType: Seq[X1[Int]] => C[X1[Int]] + )(implicit + ce: TypedEncoder[C[X1[Int]]] + ): (Unit, Unit) = evalCodeGens { + + implicit val cte = TypedExpressionEncoder[C[X1[Int]]] + implicit val e = implicitly[TypedEncoder[ContainerOf[C]]] + implicit val te = TypedExpressionEncoder[ContainerOf[C]] + implicit val xe = implicitly[TypedEncoder[X1[ContainerOf[C]]]] + implicit val xte = TypedExpressionEncoder[X1[ContainerOf[C]]] + val v = toType((1 to 20).map(X1(_))) + val ds = { + sqlContext.createDataset(Seq(X1[ContainerOf[C]](ContainerOf[C](v)))) + } + ds.head.a.a shouldBe v + () + } + + test("It should serde a Seq of Objects") { + performCollection[Seq](_) + } + + test("It should serde a Set of Objects") { + performCollection[Set](_) + } + + test("It should serde a Vector of Objects") { + performCollection[Vector](_.toVector) + } + + test("It should serde a TreeSet of Objects") { + // only needed for 2.12 + implicit val ordering = new Ordering[X1[Int]] { + val intordering = implicitly[Ordering[Int]] + + override def compare(x: X1[Int], y: X1[Int]): Int = + intordering.compare(x.a, y.a) + } + + performCollection[TreeSet](TreeSet.newBuilder.++=(_).result()) + } + + test("It should serde a List of Objects") { + performCollection[List](_.toList) + } + + test("It should serde a ListSet of Objects") { + performCollection[ListSet](ListSet.newBuilder.++=(_).result()) + } } diff --git a/dataset/src/test/scala/frameless/functions/UdfTests.scala b/dataset/src/test/scala/frameless/functions/UdfTests.scala index 10e65180f..af452cba4 100644 --- a/dataset/src/test/scala/frameless/functions/UdfTests.scala +++ b/dataset/src/test/scala/frameless/functions/UdfTests.scala @@ -4,182 +4,257 @@ package functions import org.scalacheck.Prop import org.scalacheck.Prop._ +import scala.collection.immutable.{ ListSet, TreeSet } + class UdfTests extends TypedDatasetSuite { test("one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X1[A]], f1: A => B): Prop = { - val dataset: TypedDataset[X1[A]] = TypedDataset.create(data) - val u1 = udf[X1[A], A, B](f1) - val u2 = dataset.makeUDF(f1) - val A = dataset.col[A]('a) - - // filter forces whole codegen - val codegen = dataset.deserialized.filter((_:X1[A]) => true).select(u1(A)).collect().run().toVector - - // otherwise it uses local relation - val local = dataset.select(u2(A)).collect().run().toVector - - val d = data.map(x => f1(x.a)) - - (codegen ?= d) && (local ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder]( + data: Vector[X1[A]], + f1: A => B + ): Prop = { + val dataset: TypedDataset[X1[A]] = TypedDataset.create(data) + val u1 = udf[X1[A], A, B](f1) + val u2 = dataset.makeUDF(f1) + val A = dataset.col[A]('a) + + // filter forces whole codegen + val codegen = dataset.deserialized + .filter((_: X1[A]) => true) + .select(u1(A)) + .collect() + .run() + .toVector + + // otherwise it uses local relation + val local = dataset.select(u2(A)).collect().run().toVector + + val d = data.map(x => f1(x.a)) + + (codegen ?= d) && (local ?= d) + } + + check(forAll(prop[Int, Int] _)) + check(forAll(prop[String, String] _)) + check(forAll(prop[Option[Int], Option[Int]] _)) + check(forAll(prop[X1[Int], X1[Int]] _)) + check(forAll(prop[X1[Option[Int]], X1[Option[Int]]] _)) + + // TODO doesn't work for the same reason as `collect` + // check(forAll(prop[X1[Option[X1[Int]]], X1[Option[X1[Option[Int]]]]] _)) + + // Vector/List isn't supported by MapObjects, not all collections are equal see #804 + check(forAll(prop[Option[Seq[String]], Option[Seq[String]]] _)) + check(forAll(prop[Option[List[String]], Option[List[String]]] _)) + check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _)) + + // ListSet/TreeSet weren't supported before #804 + check(forAll(prop[Option[Set[String]], Option[Set[String]]] _)) + check(forAll(prop[Option[ListSet[String]], Option[ListSet[String]]] _)) + check(forAll(prop[Option[TreeSet[String]], Option[TreeSet[String]]] _)) + + def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = + prop(Vector(X1(a)), f) + + check( + forAll( + prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _ + ) + ) + check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _)) } - - check(forAll(prop[Int, Int] _)) - check(forAll(prop[String, String] _)) - check(forAll(prop[Option[Int], Option[Int]] _)) - check(forAll(prop[X1[Int], X1[Int]] _)) - check(forAll(prop[X1[Option[Int]], X1[Option[Int]]] _)) - - // TODO doesn't work for the same reason as `collect` - // check(forAll(prop[X1[Option[X1[Int]]], X1[Option[X1[Option[Int]]]]] _)) - - check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _)) - - def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = prop(Vector(X1(a)), f) - - check(forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _)) - check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _)) } test("multiple one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: A => A, f2: B => B, f3: C => C): Prop = { - val dataset = TypedDataset.create(data) - val u11 = udf[X3[A, B, C], A, A](f1) - val u21 = udf[X3[A, B, C], B, B](f2) - val u31 = udf[X3[A, B, C], C, C](f3) - val u12 = dataset.makeUDF(f1) - val u22 = dataset.makeUDF(f2) - val u32 = dataset.makeUDF(f3) - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector - val dataset22 = dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector - val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c))) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: A => A, + f2: B => B, + f3: C => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u11 = udf[X3[A, B, C], A, A](f1) + val u21 = udf[X3[A, B, C], B, B](f2) + val u31 = udf[X3[A, B, C], C, C](f3) + val u12 = dataset.makeUDF(f1) + val u22 = dataset.makeUDF(f2) + val u32 = dataset.makeUDF(f3) + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = + dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector + val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c))) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) + check(forAll(prop[X3[Int, String, Boolean], Int, Int] _)) + check(forAll(prop[X3U[Int, String, Boolean], Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) - check(forAll(prop[X3[Int, String, Boolean], Int, Int] _)) - check(forAll(prop[X3U[Int, String, Boolean], Int, Int] _)) } test("two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X3[A, B, C], A, B, C](f1) - val u2 = dataset.makeUDF(f1) - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - - val dataset21 = dataset.select(u1(A, B)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B)).collect().run().toVector - val d = data.map(x => f1(x.a, x.b)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: (A, B) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X3[A, B, C], A, B, C](f1) + val u2 = dataset.makeUDF(f1) + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + + val dataset21 = dataset.select(u1(A, B)).collect().run().toVector + val dataset22 = dataset.select(u2(A, B)).collect().run().toVector + val d = data.map(x => f1(x.a, x.b)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("multiple two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C, f2: (B, C) => A): Prop = { - val dataset = TypedDataset.create(data) - val u11 = udf[X3[A, B, C], A, B, C](f1) - val u12 = dataset.makeUDF(f1) - val u21 = udf[X3[A, B, C], B, C, A](f2) - val u22 = dataset.makeUDF(f2) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u11(A, B), u21(B, C)).collect().run().toVector - val dataset22 = dataset.select(u12(A, B), u22(B, C)).collect().run().toVector - val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c))) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: (A, B) => C, + f2: (B, C) => A + ): Prop = { + val dataset = TypedDataset.create(data) + val u11 = udf[X3[A, B, C], A, B, C](f1) + val u12 = dataset.makeUDF(f1) + val u21 = udf[X3[A, B, C], B, C, A](f2) + val u22 = dataset.makeUDF(f2) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = + dataset.select(u11(A, B), u21(B, C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A, B), u22(B, C)).collect().run().toVector + val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c))) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("three argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f: (A, B, C) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X3[A, B, C], A, B, C, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector - val d = data.map(x => f(x.a, x.b, x.c)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f: (A, B, C) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X3[A, B, C], A, B, C, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector + val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) + } } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("four argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder] - (data: Vector[X4[A, B, C, D]], f: (A, B, C, D) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - val D = dataset.col[D]('d) - - val dataset21 = dataset.select(u1(A, B, C, D)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C, D)).collect().run().toVector - val d = data.map(x => f(x.a, x.b, x.c, x.d)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]], + f: (A, B, C, D) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + val D = dataset.col[D]('d) + + val dataset21 = + dataset.select(u1(A, B, C, D)).collect().run().toVector + val dataset22 = + dataset.select(u2(A, B, C, D)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c, x.d)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int, Int] _)) + check(forAll(prop[String, Int, Int, String] _)) + check(forAll(prop[String, String, String, String] _)) + check(forAll(prop[String, Long, String, String] _)) + check(forAll(prop[String, Boolean, Boolean, String] _)) + } } - - check(forAll(prop[Int, Int, Int, Int] _)) - check(forAll(prop[String, Int, Int, String] _)) - check(forAll(prop[String, String, String, String] _)) - check(forAll(prop[String, Long, String, String] _)) - check(forAll(prop[String, Boolean, Boolean, String] _)) } test("five argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder, E: TypedEncoder] - (data: Vector[X5[A, B, C, D, E]], f: (A, B, C, D, E) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - val D = dataset.col[D]('d) - val E = dataset.col[E]('e) - - val dataset21 = dataset.select(u1(A, B, C, D, E)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C, D, E)).collect().run().toVector - val d = data.map(x => f(x.a, x.b, x.c, x.d, x.e)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder, + E: TypedEncoder + ](data: Vector[X5[A, B, C, D, E]], + f: (A, B, C, D, E) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + val D = dataset.col[D]('d) + val E = dataset.col[E]('e) + + val dataset21 = + dataset.select(u1(A, B, C, D, E)).collect().run().toVector + val dataset22 = + dataset.select(u2(A, B, C, D, E)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c, x.d, x.e)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int, Int, Int] _)) + } } - - check(forAll(prop[Int, Int, Int, Int, Int] _)) } } diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala index 82ff375c9..601613c81 100644 --- a/dataset/src/test/scala/frameless/package.scala +++ b/dataset/src/test/scala/frameless/package.scala @@ -1,9 +1,13 @@ import java.time.format.DateTimeFormatter -import java.time.{LocalDateTime => JavaLocalDateTime} +import java.time.{ LocalDateTime => JavaLocalDateTime } +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.scalacheck.{ Arbitrary, Cogen, Gen } -import org.scalacheck.{Arbitrary, Gen} +import scala.collection.immutable.{ ListSet, TreeSet } package object frameless { + /** Fixed decimal point to avoid precision problems specific to Spark */ implicit val arbBigDecimal: Arbitrary[BigDecimal] = Arbitrary { for { @@ -30,11 +34,62 @@ package object frameless { } // see issue with scalacheck non serializable Vector: https://github.com/rickynils/scalacheck/issues/315 - implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] = + implicit def arbVector[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[Vector[A]] = Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector)) def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary + implicit def arbSeq[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[scala.collection.Seq[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector.toSeq)) + + def seqGen[A: Arbitrary]: Gen[scala.collection.Seq[A]] = arbSeq[A].arbitrary + + implicit def arbList[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[List[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(_.toList)) + + def listGen[A: Arbitrary]: Gen[List[A]] = arbList[A].arbitrary + + implicit def arbSet[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[Set[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(Set.newBuilder.++=(_).result())) + + def setGen[A: Arbitrary]: Gen[Set[A]] = arbSet[A].arbitrary + + implicit def cogenListSet[A: Cogen: Ordering]: Cogen[ListSet[A]] = + Cogen.it(_.toVector.sorted.iterator) + + implicit def arbListSet[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[ListSet[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(ListSet.newBuilder.++=(_).result())) + + def listSetGen[A: Arbitrary]: Gen[ListSet[A]] = arbListSet[A].arbitrary + + implicit def cogenTreeSet[A: Cogen: Ordering]: Cogen[TreeSet[A]] = + Cogen.it(_.toVector.sorted.iterator) + + implicit def arbTreeSet[A]( + implicit + A: Arbitrary[A], + o: Ordering[A] + ): Arbitrary[TreeSet[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(TreeSet.newBuilder.++=(_).result())) + + def treeSetGen[A: Arbitrary: Ordering]: Gen[TreeSet[A]] = + arbTreeSet[A].arbitrary + implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary { for { int <- Arbitrary.arbitrary[Int] @@ -42,7 +97,8 @@ package object frameless { } yield new UdtEncodedClass(int, doubles.toArray) } - val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") + val dateTimeFormatter: DateTimeFormatter = + DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary { for { @@ -61,7 +117,18 @@ package object frameless { localDate <- listOfDates } yield localDate.format(dateTimeFormatter) - val TEST_OUTPUT_DIR = "target/test-output" + private var outputDir: String = _ + + /** allow usage on non-build environments */ + def setOutputDir(path: String): Unit = { + outputDir = path + } + + lazy val TEST_OUTPUT_DIR = + if (outputDir ne null) + outputDir + else + "target/test-output" /** * Will dive down causes until either the cause is true or there are no more causes @@ -72,11 +139,10 @@ package object frameless { def anyCauseHas(t: Throwable, f: Throwable => Boolean): Boolean = if (f(t)) true + else if (t.getCause ne null) + anyCauseHas(t.getCause, f) else - if (t.getCause ne null) - anyCauseHas(t.getCause, f) - else - false + false /** * Runs up to maxRuns and outputs the number of failures (times thrown) @@ -85,11 +151,11 @@ package object frameless { * @tparam T * @return the last passing thunk, or null */ - def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T ={ + def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T = { var i = 0 var r = null.asInstanceOf[T] var passed = 0 - while(i < maxRuns){ + while (i < maxRuns) { i += 1 try { r = thunk @@ -98,29 +164,36 @@ package object frameless { println(s"run $i successful") } } catch { - case t: Throwable => System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") + case t: Throwable => + System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") } } if (passed != maxRuns) { - System.err.println(s"had ${maxRuns - passed} failures out of $maxRuns runs") + System.err.println( + s"had ${maxRuns - passed} failures out of $maxRuns runs" + ) } r } - /** + /** * Runs a given thunk up to maxRuns times, restarting the thunk if tolerantOf the thrown Throwable is true * @param tolerantOf * @param maxRuns default of 20 * @param thunk * @return either a successful run result or the last error will be thrown */ - def tolerantRun[T](tolerantOf: Throwable => Boolean, maxRuns: Int = 20)(thunk: => T): T ={ + def tolerantRun[T]( + tolerantOf: Throwable => Boolean, + maxRuns: Int = 20 + )(thunk: => T + ): T = { var passed = false var i = 0 var res: T = null.asInstanceOf[T] var thrown: Throwable = null - while((i < maxRuns) && !passed) { + while ((i < maxRuns) && !passed) { try { i += 1 res = thunk @@ -139,4 +212,58 @@ package object frameless { } res } + + // from Quality, which is from Spark test versions + + // if this blows then debug on CodeGenerator 1294, 1299 and grab code.body + def forceCodeGen[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + def forceInterpreted[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + /** + * runs the same test with both eval and codegen, then does the same again using resolveWith + * + * @param f + * @tparam T + * @return + */ + def evalCodeGens[T](f: => T): (T, T) = + (forceInterpreted(f), forceCodeGen(f)) + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => conf.setConfString(k, v) } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + }