
Commit 4909dba

Merge pull request #589 from waitinfuture/support-complex-type
support complex type for MaxCompute
2 parents fb8fb4e + 91e42fe commit 4909dba

File tree

7 files changed: +379 / -310 lines


emr-maxcompute/src/main/scala/org/apache/spark/aliyun/odps/OdpsOps.scala

Lines changed: 5 additions & 47 deletions
@@ -16,9 +16,7 @@
  */
 package org.apache.spark.aliyun.odps
 
-import java.sql.SQLException
 import java.text.SimpleDateFormat
-
 import scala.reflect.ClassTag
 
 import com.aliyun.odps._
@@ -27,15 +25,14 @@ import com.aliyun.odps.account.AliyunAccount
 import com.aliyun.odps.data.Record
 import com.aliyun.odps.tunnel.TableTunnel
 import com.aliyun.odps.tunnel.io.TunnelRecordWriter
-
 import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.aliyun.utils.OdpsUtils
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.types.{Decimal, StructType}
+import org.apache.spark.sql.types.StructType
 
 class OdpsOps(@transient sc: SparkContext, accessKeyId: String,
     accessKeySecret: String, odpsUrl: String, tunnelUrl: String)
@@ -706,7 +703,7 @@ class OdpsOps(@transient sc: SparkContext, accessKeyId: String,
 
     StructType(
       columns.map(idx => {
-        odpsUtils.getCatalystType(tableSchema(idx)._1, tableSchema(idx)._2, true)
+        OdpsUtils.getCatalystType(tableSchema(idx)._1, tableSchema(idx)._2, true)
       })
     )
   }
@@ -716,46 +713,7 @@ class OdpsOps(@transient sc: SparkContext, accessKeyId: String,
     cols.sorted.map { idx =>
       val col = schema.getColumn(idx)
       try {
-        col.getTypeInfo.getOdpsType match {
-          case OdpsType.BIGINT => record.toArray.apply(idx).asInstanceOf[Long]
-          case OdpsType.BINARY =>
-            record.toArray.apply(idx).asInstanceOf[com.aliyun.odps.data.Binary].data()
-          case OdpsType.BOOLEAN => record.toArray.apply(idx).asInstanceOf[Boolean]
-          case OdpsType.CHAR => record.toArray.apply(idx).asInstanceOf[String]
-          case OdpsType.DATE => record.toArray.apply(idx).asInstanceOf[java.sql.Date].getTime
-          case OdpsType.DATETIME =>
-            val r = record.toArray.apply(idx).asInstanceOf[java.util.Date]
-            if (r != null) {
-              new java.sql.Date(r.getTime)
-            } else null
-          case OdpsType.DECIMAL =>
-            new Decimal().set(record.toArray.apply(idx).asInstanceOf[java.math.BigDecimal])
-          case OdpsType.DOUBLE => record.toArray.apply(idx).asInstanceOf[Double]
-          case OdpsType.FLOAT => record.toArray.apply(idx).asInstanceOf[Float]
-          case OdpsType.INT => record.toArray.apply(idx).asInstanceOf[Integer]
-          case OdpsType.SMALLINT => record.toArray.apply(idx).asInstanceOf[Short]
-          case OdpsType.STRING => record.getString(idx)
-          case OdpsType.TIMESTAMP =>
-            val r = record.toArray.apply(idx).asInstanceOf[java.sql.Timestamp]
-            if (r != null) {
-              r
-            } else null
-          case OdpsType.TINYINT => record.toArray.apply(idx).asInstanceOf[Byte]
-          case OdpsType.VARCHAR =>
-            record.toArray.apply(idx).asInstanceOf[com.aliyun.odps.data.Varchar].getValue
-          case OdpsType.VOID => "null"
-          case OdpsType.INTERVAL_DAY_TIME =>
-            throw new SQLException(s"Unsupported type 'INTERVAL_DAY_TIME'")
-          case OdpsType.INTERVAL_YEAR_MONTH =>
-            throw new SQLException(s"Unsupported type 'INTERVAL_YEAR_MONTH'")
-          case OdpsType.MAP =>
-            throw new SQLException(s"Unsupported type 'MAP'")
-          case OdpsType.STRUCT =>
-            throw new SQLException(s"Unsupported type 'STRUCT'")
-          case OdpsType.ARRAY =>
-            throw new SQLException(s"Unsupported type 'ARRAY'")
-          case _ => throw new SQLException(s"Unsupported type ${col.getTypeInfo.getOdpsType}")
-        }
+        OdpsUtils.odpsData2SparkData(col.getTypeInfo, false)(record.get(idx))
       } catch {
         case e: Exception =>
           log.error(s"Can not transfer record column value, idx: $idx, " +
@@ -779,7 +737,7 @@ class OdpsOps(@transient sc: SparkContext, accessKeyId: String,
     val schema = odps.tables().get(table).getSchema
     val idx = schema.getColumnIndex(name)
     val colType = schema.getColumn(name).getTypeInfo
-    val field = odpsUtils.getCatalystType(name, colType, true)
+    val field = OdpsUtils.getCatalystType(name, colType, true)
 
     (idx.toString, field.dataType.simpleString)
   }
@@ -791,7 +749,7 @@ class OdpsOps(@transient sc: SparkContext, accessKeyId: String,
     val column = schema.getColumn(idx)
     val name = column.getName
     val colType = schema.getColumn(name).getTypeInfo
-    val field = odpsUtils.getCatalystType(name, colType, true)
+    val field = OdpsUtils.getCatalystType(name, colType, true)
 
     (name, field.dataType.simpleString)
   }

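Note: the read path above now delegates all per-column conversion to OdpsUtils.odpsData2SparkData, keyed on the column's TypeInfo, instead of matching on OdpsType at every call site; that is what lets ARRAY, MAP and STRUCT columns flow through where the old code threw SQLException. The converter itself lives in OdpsUtils.scala, which is not shown in this excerpt. Below is a minimal, hypothetical sketch of the assumed shape (not the committed code; ArrayTypeInfo and getElementTypeInfo come from the ODPS SDK):

object OdpsToSparkSketch {
  import scala.collection.JavaConverters._

  import com.aliyun.odps.OdpsType
  import com.aliyun.odps.`type`.{ArrayTypeInfo, TypeInfo}

  // Dispatch on the TypeInfo once and return a per-value converter, recursing into
  // element types for complex columns (only a few cases are sketched here).
  def odpsData2SparkData(t: TypeInfo): Object => Any = t.getOdpsType match {
    case OdpsType.BIGINT => (v: Object) => v.asInstanceOf[java.lang.Long].longValue()
    case OdpsType.STRING => (v: Object) => v.toString
    case OdpsType.ARRAY =>
      val elementConverter =
        odpsData2SparkData(t.asInstanceOf[ArrayTypeInfo].getElementTypeInfo)
      (v: Object) => v.asInstanceOf[java.util.List[Object]].asScala.map(elementConverter)
    case other => throw new UnsupportedOperationException(s"Unsupported ODPS type $other")
  }
}

Building the converter once per column and applying it per value keeps the type dispatch out of the hot loop, which matches how the call sites above use it.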
emr-maxcompute/src/main/scala/org/apache/spark/aliyun/odps/datasource/ODPSRDD.scala

Lines changed: 6 additions & 117 deletions
@@ -17,23 +17,18 @@
 package org.apache.spark.aliyun.odps.datasource
 
 import java.io.EOFException
-import java.sql.{Date, SQLException}
-
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-
 import com.aliyun.odps.{Odps, PartitionSpec}
 import com.aliyun.odps.account.AliyunAccount
 import com.aliyun.odps.tunnel.TableTunnel
-import com.aliyun.odps.tunnel.io.TunnelRecordReader
-
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.aliyun.odps.OdpsPartition
+import org.apache.spark.aliyun.utils.OdpsUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.NextIterator
 
 class ODPSRDD(
@@ -67,6 +62,7 @@ class ODPSRDD(
       val parSpec = new PartitionSpec(partitionSpec)
       downloadSession = tunnel.createDownloadSession(project, table, parSpec)
     }
+    val typeInfos = downloadSession.getSchema.getColumns.asScala.map(_.getTypeInfo)
     val reader = downloadSession.openRecordReader(split.start, split.count)
     val inputMetrics = context.taskMetrics.inputMetrics
 
@@ -84,115 +80,8 @@
         schema.zipWithIndex.foreach {
           case (s: StructField, idx: Int) =>
             try {
-              s.dataType match {
-                case LongType =>
-                  val value = r.getBigint(s.name)
-                  if (value != null) {
-                    mutableRow.setLong(idx, value)
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case BooleanType =>
-                  val value = r.getBoolean(s.name)
-                  if (value != null) {
-                    mutableRow.setBoolean(idx, value)
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case DoubleType =>
-                  val value = r.getDouble(s.name)
-                  if (value != null) {
-                    mutableRow.setDouble(idx, value)
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case ShortType =>
-                  val value = r.get(s.name)
-                  if (value != null) {
-                    mutableRow.setShort(idx, value.asInstanceOf[Short])
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case ByteType =>
-                  val value = r.get(s.name)
-                  if (value != null) {
-                    mutableRow.setByte(idx, value.asInstanceOf[Byte])
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case DateType =>
-                  val value = r.get(s.name)
-                  value match {
-                    case date1: java.sql.Date =>
-                      mutableRow.update(idx, DateTimeUtils.fromJavaDate(date1))
-                    case date2: java.util.Date =>
-                      mutableRow.setInt(idx,
-                        DateTimeUtils.fromJavaDate(new Date(date2.getTime)))
-                    case null => mutableRow.update(idx, null)
-                    case _ => throw new SQLException(s"Unknown type" +
-                      s" ${value.getClass.getCanonicalName}")
-                  }
-                case TimestampType =>
-                  val value = r.get(s.name)
-                  value match {
-                    case timestamp: java.sql.Timestamp =>
-                      mutableRow.setLong(idx, DateTimeUtils.fromJavaTimestamp(timestamp))
-                    case null => mutableRow.update(idx, null)
-                    case _ => throw new SQLException(s"Unknown type" +
-                      s" ${value.getClass.getCanonicalName}")
-                  }
-                case DecimalType.SYSTEM_DEFAULT =>
-                  val value = r.get(s.name)
-                  if (value != null) {
-                    mutableRow.update(idx,
-                      new Decimal().set(value.asInstanceOf[java.math.BigDecimal]))
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case FloatType =>
-                  val value = r.get(s.name)
-                  if (value != null) {
-                    mutableRow.update(idx, value.asInstanceOf[Float])
-                  } else {
-                    mutableRow.update(idx, null)
-                  }
-                case IntegerType =>
-                  val value = r.get(s.name)
-                  value match {
-                    case e: java.lang.Integer =>
-                      mutableRow.update(idx, e.toInt)
-                    case null => mutableRow.update(idx, null)
-                    case _ => throw new SQLException(s"Unknown type" +
-                      s" ${value.getClass.getCanonicalName}")
-                  }
-                case StringType =>
-                  val value = r.get(s.name)
-                  value match {
-                    case e: com.aliyun.odps.data.Char =>
-                      mutableRow.update(idx, UTF8String.fromString(e.toString))
-                    case e: com.aliyun.odps.data.Varchar =>
-                      mutableRow.update(idx, UTF8String.fromString(e.toString))
-                    case e: String =>
-                      mutableRow.update(idx, UTF8String.fromString(e))
-                    case e: Array[Byte] =>
-                      mutableRow.update(idx, UTF8String.fromBytes(e))
-                    case null => mutableRow.update(idx, null)
-                    case _ => throw new SQLException(s"Unknown type" +
-                      s" ${value.getClass.getCanonicalName}")
-                  }
-                case BinaryType =>
-                  val value = r.get(s.name)
-                  value match {
-                    case e: com.aliyun.odps.data.Binary =>
-                      mutableRow.update(idx, e.data())
-                    case null => mutableRow.update(idx, null)
-                    case _ => throw new SQLException(s"Unknown type" +
-                      s" ${value.getClass.getCanonicalName}")
-                  }
-                case NullType =>
-                  mutableRow.setNullAt(idx)
-                case _ => throw new SQLException(s"Unknown type")
-              }
+              val value = r.get(s.name)
+              mutableRow.update(idx, OdpsUtils.odpsData2SparkData(typeInfos(idx))(value))
             } catch {
               case e: Exception =>
                 log.error(s"Can not transfer record column value, idx: $idx, " +
@@ -215,7 +104,7 @@
 
     override def close() {
       try {
-        val totalBytes = reader.asInstanceOf[TunnelRecordReader].getTotalBytes
+        val totalBytes = reader.getTotalBytes
         inputMetrics.incBytesRead(totalBytes)
         reader.close()
       } catch {

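Note: ODPSRDD now resolves the column TypeInfos once per partition from the download session's schema and writes the converted value straight into the SpecificInternalRow, so odpsData2SparkData is expected to return Catalyst's internal representations here (UTF8String for strings, ArrayData for arrays, and so on). A hypothetical illustration, assuming an array<string> column whose cell arrives from the ODPS record as a java.util.List:

object ComplexReadValueSketch {
  import scala.collection.JavaConverters._

  import org.apache.spark.sql.catalyst.util.GenericArrayData
  import org.apache.spark.unsafe.types.UTF8String

  def main(args: Array[String]): Unit = {
    // A cell of an ODPS array<string> column as the record API would hand it over.
    val cell: java.util.List[String] = java.util.Arrays.asList("a", "b")
    // Before mutableRow.update(idx, _) it has to become Catalyst internal data:
    // UTF8String elements wrapped in an ArrayData implementation.
    val internalValue = new GenericArrayData(cell.asScala.map(s => UTF8String.fromString(s)))
    println(internalValue.numElements()) // 2
  }
}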
emr-maxcompute/src/main/scala/org/apache/spark/aliyun/odps/datasource/ODPSRelation.scala

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ case class ODPSRelation(
   val tableSchema = odpsUtils.getTableSchema(project, table, false)
 
   StructType(
-    tableSchema.map(e => odpsUtils.getCatalystType(e._1, e._2, true))
+    tableSchema.map(e => OdpsUtils.getCatalystType(e._1, e._2, true))
   )
 }
 

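Note: schema inference in ODPSRelation still maps each ODPS column to a Catalyst StructField via getCatalystType, now called on the OdpsUtils companion object. For complex columns that mapping presumably recurses into the nested TypeInfos; a hypothetical sketch of such a mapping (the committed version lives in OdpsUtils.scala, which is not shown here):

object TypeMappingSketch {
  import com.aliyun.odps.OdpsType
  import com.aliyun.odps.`type`.{ArrayTypeInfo, MapTypeInfo, TypeInfo}

  import org.apache.spark.sql.types._

  // Recursive TypeInfo -> Catalyst DataType mapping; only a few cases are sketched.
  def toCatalyst(t: TypeInfo): DataType = t.getOdpsType match {
    case OdpsType.BIGINT => LongType
    case OdpsType.STRING => StringType
    case OdpsType.ARRAY =>
      ArrayType(toCatalyst(t.asInstanceOf[ArrayTypeInfo].getElementTypeInfo))
    case OdpsType.MAP =>
      val m = t.asInstanceOf[MapTypeInfo]
      MapType(toCatalyst(m.getKeyTypeInfo), toCatalyst(m.getValueTypeInfo))
    case other => throw new UnsupportedOperationException(s"Unsupported ODPS type $other")
  }
}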
emr-maxcompute/src/main/scala/org/apache/spark/aliyun/odps/datasource/ODPSWriter.scala

Lines changed: 12 additions & 59 deletions
@@ -16,18 +16,14 @@
  */
 package org.apache.spark.aliyun.odps.datasource
 
-import java.sql.{Date, SQLException}
-
 import com.aliyun.odps._
+import com.aliyun.odps.`type`.TypeInfo
 import com.aliyun.odps.account.AliyunAccount
-import com.aliyun.odps.data.Binary
 import com.aliyun.odps.tunnel.TableTunnel
 import org.slf4j.LoggerFactory
-
 import org.apache.spark.TaskContext
 import org.apache.spark.aliyun.utils.OdpsUtils
 import org.apache.spark.sql.{DataFrame, Row, SaveMode}
-import org.apache.spark.sql.types._
 
 class ODPSWriter(
     accessKeyId: String,
@@ -117,12 +113,8 @@
     }
     val uploadId = uploadSession.getId
 
-    def writeToFile(schema: Array[(String, OdpsType)], iter: Iterator[Row]) {
-      val account_ = new AliyunAccount(accessKeyId, accessKeySecret)
-      val odps_ = new Odps(account_)
-      odps_.setDefaultProject(project)
-      odps_.setEndpoint(odpsUrl)
-      val tunnel_ = new TableTunnel(odps_)
+    def writeToFile(odps: Odps, schema: Array[(String, TypeInfo)], iter: Iterator[Row]) {
+      val tunnel_ = new TableTunnel(odps)
       tunnel_.setEndpoint(tunnelUrl)
       val uploadSession_ = if (isPartitionTable) {
         val parSpec = new PartitionSpec(partitionSpec)
@@ -140,52 +132,9 @@
         val record = uploadSession_.newRecord()
 
         schema.zipWithIndex.foreach {
-          case (s: (String, OdpsType), idx: Int) =>
+          case (s: (String, TypeInfo), idx: Int) =>
             try {
-              s._2 match {
-                case OdpsType.BIGINT =>
-                  record.setBigint(s._1, value.get(idx).toString.toLong)
-                case OdpsType.BINARY =>
-                  record.set(s._1, new Binary(value.getAs[Array[Byte]](idx)))
-                case OdpsType.BOOLEAN =>
-                  record.setBoolean(s._1, value.getBoolean(idx))
-                case OdpsType.CHAR =>
-                  record.set(s._1, new com.aliyun.odps.data.Char(value.get(idx).toString))
-                case OdpsType.DATE =>
-                  record.set(s._1, value.getAs[Date](idx))
-                case OdpsType.DATETIME =>
-                  record.set(s._1, new java.util.Date(value.getAs[Date](idx).getTime))
-                case OdpsType.DECIMAL =>
-                  record.set(s._1, Decimal(value.get(idx).toString).toJavaBigDecimal)
-                case OdpsType.DOUBLE =>
-                  record.setDouble(s._1, value.getDouble(idx))
-                case OdpsType.FLOAT =>
-                  record.set(s._1, value.get(idx).toString.toFloat)
-                case OdpsType.INT =>
-                  record.set(s._1, value.get(idx).toString.toInt)
-                case OdpsType.SMALLINT =>
-                  record.set(s._1, value.get(idx).toString.toShort)
-                case OdpsType.STRING =>
-                  record.setString(s._1, value.get(idx).toString)
-                case OdpsType.TINYINT =>
-                  record.set(s._1, value.getAs[Byte](idx))
-                case OdpsType.VARCHAR =>
-                  record.set(s._1, new com.aliyun.odps.data.Varchar(value.get(idx).toString))
-                case OdpsType.TIMESTAMP =>
-                  record.setDatetime(s._1, value.getAs[java.sql.Timestamp](idx))
-                case OdpsType.VOID => record.set(s._1, null)
-                case OdpsType.INTERVAL_DAY_TIME =>
-                  throw new SQLException(s"Unsupported type 'INTERVAL_DAY_TIME'")
-                case OdpsType.INTERVAL_YEAR_MONTH =>
-                  throw new SQLException(s"Unsupported type 'INTERVAL_YEAR_MONTH'")
-                case OdpsType.MAP =>
-                  throw new SQLException(s"Unsupported type 'MAP'")
-                case OdpsType.STRUCT =>
-                  throw new SQLException(s"Unsupported type 'STRUCT'")
-                case OdpsType.ARRAY =>
-                  throw new SQLException(s"Unsupported type 'ARRAY'")
-                case _ => throw new SQLException(s"Unsupported type ${s._2}")
-              }
+              record.set(s._1, OdpsUtils.sparkData2OdpsData(s._2)(value.get(idx).asInstanceOf[Object]))
             } catch {
               case e: NullPointerException =>
                 if (value.get(idx) == null) {
@@ -202,11 +151,15 @@
       writer.close()
     }
 
-    val dataSchema = odpsUtils.getTableSchema(project, table, false)
-      .map{ e => (e._1, e._2.getOdpsType) }
     data.foreachPartition {
      iterator =>
-        writeToFile(dataSchema, iterator)
+        val account_ = new AliyunAccount(accessKeyId, accessKeySecret)
+        val odps = new Odps(account_)
+        odps.setDefaultProject(project)
+        odps.setEndpoint(odpsUrl)
+        val odpsUtils = new OdpsUtils(odps)
+        val dataSchema = odpsUtils.getTableSchema(project, table, false)
+        writeToFile(odps, dataSchema, iterator)
     }
     val arr = Array.tabulate(data.rdd.partitions.length)(l => Long.box(l))
     uploadSession.commit(arr)

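Note: on the write path each Row value is handed to OdpsUtils.sparkData2OdpsData, again keyed on the target column's TypeInfo, where the old code matched on OdpsType and rejected ARRAY, MAP and STRUCT. The Odps client and the table schema (now carried as TypeInfo rather than OdpsType) are built per partition inside foreachPartition instead of once on the driver. A hypothetical sketch of the write-side converter shape assumed by the record.set call above (not the committed code):

object SparkToOdpsSketch {
  import scala.collection.JavaConverters._

  import com.aliyun.odps.OdpsType
  import com.aliyun.odps.`type`.{ArrayTypeInfo, TypeInfo}

  // Mirror image of the read-side converter: turn the external Row value Spark hands
  // over into a Java value the ODPS Record accepts (only a few cases are sketched).
  def sparkData2OdpsData(t: TypeInfo): Object => AnyRef = t.getOdpsType match {
    case OdpsType.BIGINT => (v: Object) => java.lang.Long.valueOf(v.toString)
    case OdpsType.STRING => (v: Object) => v.toString
    case OdpsType.ARRAY =>
      // Spark's external Row exposes array columns as Seq; ODPS records expect java.util.List.
      val elementConverter =
        sparkData2OdpsData(t.asInstanceOf[ArrayTypeInfo].getElementTypeInfo)
      (v: Object) => v.asInstanceOf[Seq[Object]].map(elementConverter).asJava
    case other => throw new UnsupportedOperationException(s"Unsupported ODPS type $other")
  }
}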